You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

475 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. "github.com/tendermint/tendermint/crypto/secp256k1"
  20. cmn "github.com/tendermint/tendermint/libs/common"
  21. "github.com/tendermint/tendermint/libs/rand"
  22. )
  // kvstoreConn is an in-memory, bidirectional connection made from two
// io.Pipe halves: reads come from PipeReader and writes go to PipeWriter.
type kvstoreConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close closes both halves of the connection. The write half is closed with
// io.EOF so the peer's reads on the other end of that pipe return io.EOF.
// If closing the write half fails, that error is returned; otherwise the
// result of closing the read half is returned.
func (drw kvstoreConn) Close() error {
	err2 := drw.PipeWriter.CloseWithError(io.EOF)
	err1 := drw.PipeReader.Close()
	if err2 != nil {
		return err2
	}
	return err1
}

// makeKVStoreConnPair returns two connected kvstoreConns: bytes written to
// one can be read from the other. Each returned ReadWriteCloser is akin to a
// net.Conn.
func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
	barReader, fooWriter := io.Pipe()
	fooReader, barWriter := io.Pipe()
	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
}
  41. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  42. var fooConn, barConn = makeKVStoreConnPair()
  43. var fooPrvKey = ed25519.GenPrivKey()
  44. var fooPubKey = fooPrvKey.PubKey()
  45. var barPrvKey = ed25519.GenPrivKey()
  46. var barPubKey = barPrvKey.PubKey()
  47. // Make connections from both sides in parallel.
  48. var trs, ok = cmn.Parallel(
  49. func(_ int) (val interface{}, abort bool, err error) {
  50. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  51. if err != nil {
  52. tb.Errorf("failed to establish SecretConnection for foo: %v", err)
  53. return nil, true, err
  54. }
  55. remotePubBytes := fooSecConn.RemotePubKey()
  56. if !remotePubBytes.Equals(barPubKey) {
  57. err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  58. barPubKey, fooSecConn.RemotePubKey())
  59. tb.Error(err)
  60. return nil, false, err
  61. }
  62. return nil, false, nil
  63. },
  64. func(_ int) (val interface{}, abort bool, err error) {
  65. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  66. if barSecConn == nil {
  67. tb.Errorf("failed to establish SecretConnection for bar: %v", err)
  68. return nil, true, err
  69. }
  70. remotePubBytes := barSecConn.RemotePubKey()
  71. if !remotePubBytes.Equals(fooPubKey) {
  72. err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  73. fooPubKey, barSecConn.RemotePubKey())
  74. tb.Error(err)
  75. return nil, false, nil
  76. }
  77. return nil, false, nil
  78. },
  79. )
  80. require.Nil(tb, trs.FirstError())
  81. require.True(tb, ok, "Unexpected task abortion")
  82. return fooSecConn, barSecConn
  83. }
  84. func TestSecretConnectionHandshake(t *testing.T) {
  85. fooSecConn, barSecConn := makeSecretConnPair(t)
  86. if err := fooSecConn.Close(); err != nil {
  87. t.Error(err)
  88. }
  89. if err := barSecConn.Close(); err != nil {
  90. t.Error(err)
  91. }
  92. }
  93. func TestConcurrentWrite(t *testing.T) {
  94. fooSecConn, barSecConn := makeSecretConnPair(t)
  95. fooWriteText := rand.RandStr(dataMaxSize)
  96. // write from two routines.
  97. // should be safe from race according to net.Conn:
  98. // https://golang.org/pkg/net/#Conn
  99. n := 100
  100. wg := new(sync.WaitGroup)
  101. wg.Add(3)
  102. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  103. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  104. // Consume reads from bar's reader
  105. readLots(t, wg, barSecConn, n*2)
  106. wg.Wait()
  107. if err := fooSecConn.Close(); err != nil {
  108. t.Error(err)
  109. }
  110. }
  111. func TestConcurrentRead(t *testing.T) {
  112. fooSecConn, barSecConn := makeSecretConnPair(t)
  113. fooWriteText := rand.RandStr(dataMaxSize)
  114. n := 100
  115. // read from two routines.
  116. // should be safe from race according to net.Conn:
  117. // https://golang.org/pkg/net/#Conn
  118. wg := new(sync.WaitGroup)
  119. wg.Add(3)
  120. go readLots(t, wg, fooSecConn, n/2)
  121. go readLots(t, wg, fooSecConn, n/2)
  122. // write to bar
  123. writeLots(t, wg, barSecConn, fooWriteText, n)
  124. wg.Wait()
  125. if err := fooSecConn.Close(); err != nil {
  126. t.Error(err)
  127. }
  128. }
  // writeLots writes txt to conn n times. On the first write error it reports
// the failure through t and stops. wg.Done is called on every return path.
func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
	defer wg.Done()
	// Convert once; the same bytes are written on every iteration.
	data := []byte(txt)
	for i := 0; i < n; i++ {
		if _, err := conn.Write(data); err != nil {
			// conn may be either end of a pair (foo or bar), so don't name one.
			t.Errorf("failed to write to conn (write %d of %d): %v", i+1, n, err)
			return
		}
	}
}
  139. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  140. readBuffer := make([]byte, dataMaxSize)
  141. for i := 0; i < n; i++ {
  142. _, err := conn.Read(readBuffer)
  143. assert.NoError(t, err)
  144. }
  145. wg.Done()
  146. }
  147. func TestSecretConnectionReadWrite(t *testing.T) {
  148. fooConn, barConn := makeKVStoreConnPair()
  149. fooWrites, barWrites := []string{}, []string{}
  150. fooReads, barReads := []string{}, []string{}
  151. // Pre-generate the things to write (for foo & bar)
  152. for i := 0; i < 100; i++ {
  153. fooWrites = append(fooWrites, rand.RandStr((rand.RandInt()%(dataMaxSize*5))+1))
  154. barWrites = append(barWrites, rand.RandStr((rand.RandInt()%(dataMaxSize*5))+1))
  155. }
  156. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  157. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  158. return func(_ int) (interface{}, bool, error) {
  159. // Initiate cryptographic private key and secret connection trhough nodeConn.
  160. nodePrvKey := ed25519.GenPrivKey()
  161. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  162. if err != nil {
  163. t.Errorf("failed to establish SecretConnection for node: %v", err)
  164. return nil, true, err
  165. }
  166. // In parallel, handle some reads and writes.
  167. var trs, ok = cmn.Parallel(
  168. func(_ int) (interface{}, bool, error) {
  169. // Node writes:
  170. for _, nodeWrite := range nodeWrites {
  171. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  172. if err != nil {
  173. t.Errorf("failed to write to nodeSecretConn: %v", err)
  174. return nil, true, err
  175. }
  176. if n != len(nodeWrite) {
  177. err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  178. t.Error(err)
  179. return nil, true, err
  180. }
  181. }
  182. if err := nodeConn.PipeWriter.Close(); err != nil {
  183. t.Error(err)
  184. return nil, true, err
  185. }
  186. return nil, false, nil
  187. },
  188. func(_ int) (interface{}, bool, error) {
  189. // Node reads:
  190. readBuffer := make([]byte, dataMaxSize)
  191. for {
  192. n, err := nodeSecretConn.Read(readBuffer)
  193. if err == io.EOF {
  194. if err := nodeConn.PipeReader.Close(); err != nil {
  195. t.Error(err)
  196. return nil, true, err
  197. }
  198. return nil, false, nil
  199. } else if err != nil {
  200. t.Errorf("failed to read from nodeSecretConn: %v", err)
  201. return nil, true, err
  202. }
  203. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  204. }
  205. },
  206. )
  207. assert.True(t, ok, "Unexpected task abortion")
  208. // If error:
  209. if trs.FirstError() != nil {
  210. return nil, true, trs.FirstError()
  211. }
  212. // Otherwise:
  213. return nil, false, nil
  214. }
  215. }
  216. // Run foo & bar in parallel
  217. var trs, ok = cmn.Parallel(
  218. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  219. genNodeRunner("bar", barConn, barWrites, &barReads),
  220. )
  221. require.Nil(t, trs.FirstError())
  222. require.True(t, ok, "unexpected task abortion")
  223. // A helper to ensure that the writes and reads match.
  224. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  225. compareWritesReads := func(writes []string, reads []string) {
  226. for {
  227. // Pop next write & corresponding reads
  228. var read, write string = "", writes[0]
  229. var readCount = 0
  230. for _, readChunk := range reads {
  231. read += readChunk
  232. readCount++
  233. if len(write) <= len(read) {
  234. break
  235. }
  236. if len(write) <= dataMaxSize {
  237. break // atomicity of small writes
  238. }
  239. }
  240. // Compare
  241. if write != read {
  242. t.Errorf("expected to read %X, got %X", write, read)
  243. }
  244. // Iterate
  245. writes = writes[1:]
  246. reads = reads[readCount:]
  247. if len(writes) == 0 {
  248. break
  249. }
  250. }
  251. }
  252. compareWritesReads(fooWrites, barReads)
  253. compareWritesReads(barWrites, fooReads)
  254. }
  // update is set by running `go test -update` from within this module; when
// true, the golden test vector file is regenerated before it is checked.
var (
	update = flag.Bool("update", false, "update .golden files")
)
  258. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  259. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  260. if *update {
  261. t.Logf("Updating golden test vector file %s", goldenFilepath)
  262. data := createGoldenTestVectors(t)
  263. cmn.WriteFile(goldenFilepath, []byte(data), 0644)
  264. }
  265. f, err := os.Open(goldenFilepath)
  266. if err != nil {
  267. log.Fatal(err)
  268. }
  269. defer f.Close()
  270. scanner := bufio.NewScanner(f)
  271. for scanner.Scan() {
  272. line := scanner.Text()
  273. params := strings.Split(line, ",")
  274. randSecretVector, err := hex.DecodeString(params[0])
  275. require.Nil(t, err)
  276. randSecret := new([32]byte)
  277. copy((*randSecret)[:], randSecretVector)
  278. locIsLeast, err := strconv.ParseBool(params[1])
  279. require.Nil(t, err)
  280. expectedRecvSecret, err := hex.DecodeString(params[2])
  281. require.Nil(t, err)
  282. expectedSendSecret, err := hex.DecodeString(params[3])
  283. require.Nil(t, err)
  284. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  285. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  286. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  287. }
  288. }
  289. type privKeyWithNilPubKey struct {
  290. orig crypto.PrivKey
  291. }
  292. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  293. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  294. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  295. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  296. func TestNilPubkey(t *testing.T) {
  297. var fooConn, barConn = makeKVStoreConnPair()
  298. var fooPrvKey = ed25519.GenPrivKey()
  299. var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
  300. go func() {
  301. _, err := MakeSecretConnection(barConn, barPrvKey)
  302. assert.NoError(t, err)
  303. }()
  304. assert.NotPanics(t, func() {
  305. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  306. if assert.Error(t, err) {
  307. assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
  308. }
  309. })
  310. }
  311. func TestNonEd25519Pubkey(t *testing.T) {
  312. var fooConn, barConn = makeKVStoreConnPair()
  313. var fooPrvKey = ed25519.GenPrivKey()
  314. var barPrvKey = secp256k1.GenPrivKey()
  315. go func() {
  316. _, err := MakeSecretConnection(barConn, barPrvKey)
  317. assert.NoError(t, err)
  318. }()
  319. assert.NotPanics(t, func() {
  320. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  321. if assert.Error(t, err) {
  322. assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
  323. }
  324. })
  325. }
  326. // Creates the data for a test vector file.
  327. // The file format is:
  328. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  329. func createGoldenTestVectors(t *testing.T) string {
  330. data := ""
  331. for i := 0; i < 32; i++ {
  332. randSecretVector := rand.RandBytes(32)
  333. randSecret := new([32]byte)
  334. copy((*randSecret)[:], randSecretVector)
  335. data += hex.EncodeToString((*randSecret)[:]) + ","
  336. locIsLeast := rand.RandBool()
  337. data += strconv.FormatBool(locIsLeast) + ","
  338. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  339. data += hex.EncodeToString((*recvSecret)[:]) + ","
  340. data += hex.EncodeToString((*sendSecret)[:]) + ","
  341. }
  342. return data
  343. }
  344. func BenchmarkWriteSecretConnection(b *testing.B) {
  345. b.StopTimer()
  346. b.ReportAllocs()
  347. fooSecConn, barSecConn := makeSecretConnPair(b)
  348. randomMsgSizes := []int{
  349. dataMaxSize / 10,
  350. dataMaxSize / 3,
  351. dataMaxSize / 2,
  352. dataMaxSize,
  353. dataMaxSize * 3 / 2,
  354. dataMaxSize * 2,
  355. dataMaxSize * 7 / 2,
  356. }
  357. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  358. for _, size := range randomMsgSizes {
  359. fooWriteBytes = append(fooWriteBytes, rand.RandBytes(size))
  360. }
  361. // Consume reads from bar's reader
  362. go func() {
  363. readBuffer := make([]byte, dataMaxSize)
  364. for {
  365. _, err := barSecConn.Read(readBuffer)
  366. if err == io.EOF {
  367. return
  368. } else if err != nil {
  369. b.Errorf("failed to read from barSecConn: %v", err)
  370. return
  371. }
  372. }
  373. }()
  374. b.StartTimer()
  375. for i := 0; i < b.N; i++ {
  376. idx := rand.RandIntn(len(fooWriteBytes))
  377. _, err := fooSecConn.Write(fooWriteBytes[idx])
  378. if err != nil {
  379. b.Errorf("failed to write to fooSecConn: %v", err)
  380. return
  381. }
  382. }
  383. b.StopTimer()
  384. if err := fooSecConn.Close(); err != nil {
  385. b.Error(err)
  386. }
  387. //barSecConn.Close() race condition
  388. }
  389. func BenchmarkReadSecretConnection(b *testing.B) {
  390. b.StopTimer()
  391. b.ReportAllocs()
  392. fooSecConn, barSecConn := makeSecretConnPair(b)
  393. randomMsgSizes := []int{
  394. dataMaxSize / 10,
  395. dataMaxSize / 3,
  396. dataMaxSize / 2,
  397. dataMaxSize,
  398. dataMaxSize * 3 / 2,
  399. dataMaxSize * 2,
  400. dataMaxSize * 7 / 2,
  401. }
  402. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  403. for _, size := range randomMsgSizes {
  404. fooWriteBytes = append(fooWriteBytes, rand.RandBytes(size))
  405. }
  406. go func() {
  407. for i := 0; i < b.N; i++ {
  408. idx := rand.RandIntn(len(fooWriteBytes))
  409. _, err := fooSecConn.Write(fooWriteBytes[idx])
  410. if err != nil {
  411. b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  412. return
  413. }
  414. }
  415. }()
  416. b.StartTimer()
  417. for i := 0; i < b.N; i++ {
  418. readBuffer := make([]byte, dataMaxSize)
  419. _, err := barSecConn.Read(readBuffer)
  420. if err == io.EOF {
  421. return
  422. } else if err != nil {
  423. b.Fatalf("Failed to read from barSecConn: %v", err)
  424. }
  425. }
  426. b.StopTimer()
  427. }