You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

480 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. "github.com/tendermint/tendermint/crypto/secp256k1"
  20. "github.com/tendermint/tendermint/libs/async"
  21. tmos "github.com/tendermint/tendermint/libs/os"
  22. tmrand "github.com/tendermint/tendermint/libs/rand"
  23. )
  24. // Run go test -update from within this module
  25. // to update the golden test vector file
  26. var update = flag.Bool("update", false, "update .golden files")
  27. type kvstoreConn struct {
  28. *io.PipeReader
  29. *io.PipeWriter
  30. }
  31. func (drw kvstoreConn) Close() (err error) {
  32. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  33. err1 := drw.PipeReader.Close()
  34. if err2 != nil {
  35. return err
  36. }
  37. return err1
  38. }
  39. type privKeyWithNilPubKey struct {
  40. orig crypto.PrivKey
  41. }
  42. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  43. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  44. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  45. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  46. func TestSecretConnectionHandshake(t *testing.T) {
  47. fooSecConn, barSecConn := makeSecretConnPair(t)
  48. if err := fooSecConn.Close(); err != nil {
  49. t.Error(err)
  50. }
  51. if err := barSecConn.Close(); err != nil {
  52. t.Error(err)
  53. }
  54. }
  55. func TestConcurrentWrite(t *testing.T) {
  56. fooSecConn, barSecConn := makeSecretConnPair(t)
  57. fooWriteText := tmrand.Str(dataMaxSize)
  58. // write from two routines.
  59. // should be safe from race according to net.Conn:
  60. // https://golang.org/pkg/net/#Conn
  61. n := 100
  62. wg := new(sync.WaitGroup)
  63. wg.Add(3)
  64. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  65. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  66. // Consume reads from bar's reader
  67. readLots(t, wg, barSecConn, n*2)
  68. wg.Wait()
  69. if err := fooSecConn.Close(); err != nil {
  70. t.Error(err)
  71. }
  72. }
  73. func TestConcurrentRead(t *testing.T) {
  74. fooSecConn, barSecConn := makeSecretConnPair(t)
  75. fooWriteText := tmrand.Str(dataMaxSize)
  76. n := 100
  77. // read from two routines.
  78. // should be safe from race according to net.Conn:
  79. // https://golang.org/pkg/net/#Conn
  80. wg := new(sync.WaitGroup)
  81. wg.Add(3)
  82. go readLots(t, wg, fooSecConn, n/2)
  83. go readLots(t, wg, fooSecConn, n/2)
  84. // write to bar
  85. writeLots(t, wg, barSecConn, fooWriteText, n)
  86. wg.Wait()
  87. if err := fooSecConn.Close(); err != nil {
  88. t.Error(err)
  89. }
  90. }
  91. func TestSecretConnectionReadWrite(t *testing.T) {
  92. fooConn, barConn := makeKVStoreConnPair()
  93. fooWrites, barWrites := []string{}, []string{}
  94. fooReads, barReads := []string{}, []string{}
  95. // Pre-generate the things to write (for foo & bar)
  96. for i := 0; i < 100; i++ {
  97. fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
  98. barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
  99. }
  100. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  101. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
  102. return func(_ int) (interface{}, bool, error) {
  103. // Initiate cryptographic private key and secret connection trhough nodeConn.
  104. nodePrvKey := ed25519.GenPrivKey()
  105. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  106. if err != nil {
  107. t.Errorf("failed to establish SecretConnection for node: %v", err)
  108. return nil, true, err
  109. }
  110. // In parallel, handle some reads and writes.
  111. var trs, ok = async.Parallel(
  112. func(_ int) (interface{}, bool, error) {
  113. // Node writes:
  114. for _, nodeWrite := range nodeWrites {
  115. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  116. if err != nil {
  117. t.Errorf("failed to write to nodeSecretConn: %v", err)
  118. return nil, true, err
  119. }
  120. if n != len(nodeWrite) {
  121. err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  122. t.Error(err)
  123. return nil, true, err
  124. }
  125. }
  126. if err := nodeConn.PipeWriter.Close(); err != nil {
  127. t.Error(err)
  128. return nil, true, err
  129. }
  130. return nil, false, nil
  131. },
  132. func(_ int) (interface{}, bool, error) {
  133. // Node reads:
  134. readBuffer := make([]byte, dataMaxSize)
  135. for {
  136. n, err := nodeSecretConn.Read(readBuffer)
  137. if err == io.EOF {
  138. if err := nodeConn.PipeReader.Close(); err != nil {
  139. t.Error(err)
  140. return nil, true, err
  141. }
  142. return nil, false, nil
  143. } else if err != nil {
  144. t.Errorf("failed to read from nodeSecretConn: %v", err)
  145. return nil, true, err
  146. }
  147. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  148. }
  149. },
  150. )
  151. assert.True(t, ok, "Unexpected task abortion")
  152. // If error:
  153. if trs.FirstError() != nil {
  154. return nil, true, trs.FirstError()
  155. }
  156. // Otherwise:
  157. return nil, false, nil
  158. }
  159. }
  160. // Run foo & bar in parallel
  161. var trs, ok = async.Parallel(
  162. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  163. genNodeRunner("bar", barConn, barWrites, &barReads),
  164. )
  165. require.Nil(t, trs.FirstError())
  166. require.True(t, ok, "unexpected task abortion")
  167. // A helper to ensure that the writes and reads match.
  168. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  169. compareWritesReads := func(writes []string, reads []string) {
  170. for {
  171. // Pop next write & corresponding reads
  172. var read, write string = "", writes[0]
  173. var readCount = 0
  174. for _, readChunk := range reads {
  175. read += readChunk
  176. readCount++
  177. if len(write) <= len(read) {
  178. break
  179. }
  180. if len(write) <= dataMaxSize {
  181. break // atomicity of small writes
  182. }
  183. }
  184. // Compare
  185. if write != read {
  186. t.Errorf("expected to read %X, got %X", write, read)
  187. }
  188. // Iterate
  189. writes = writes[1:]
  190. reads = reads[readCount:]
  191. if len(writes) == 0 {
  192. break
  193. }
  194. }
  195. }
  196. compareWritesReads(fooWrites, barReads)
  197. compareWritesReads(barWrites, fooReads)
  198. }
  199. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  200. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  201. if *update {
  202. t.Logf("Updating golden test vector file %s", goldenFilepath)
  203. data := createGoldenTestVectors(t)
  204. tmos.WriteFile(goldenFilepath, []byte(data), 0644)
  205. }
  206. f, err := os.Open(goldenFilepath)
  207. if err != nil {
  208. log.Fatal(err)
  209. }
  210. defer f.Close()
  211. scanner := bufio.NewScanner(f)
  212. for scanner.Scan() {
  213. line := scanner.Text()
  214. params := strings.Split(line, ",")
  215. randSecretVector, err := hex.DecodeString(params[0])
  216. require.Nil(t, err)
  217. randSecret := new([32]byte)
  218. copy((*randSecret)[:], randSecretVector)
  219. locIsLeast, err := strconv.ParseBool(params[1])
  220. require.Nil(t, err)
  221. expectedRecvSecret, err := hex.DecodeString(params[2])
  222. require.Nil(t, err)
  223. expectedSendSecret, err := hex.DecodeString(params[3])
  224. require.Nil(t, err)
  225. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  226. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  227. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  228. }
  229. }
  230. func TestNilPubkey(t *testing.T) {
  231. var fooConn, barConn = makeKVStoreConnPair()
  232. var fooPrvKey = ed25519.GenPrivKey()
  233. var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
  234. go func() {
  235. _, err := MakeSecretConnection(barConn, barPrvKey)
  236. assert.NoError(t, err)
  237. }()
  238. assert.NotPanics(t, func() {
  239. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  240. if assert.Error(t, err) {
  241. assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
  242. }
  243. })
  244. }
  245. func TestNonEd25519Pubkey(t *testing.T) {
  246. var fooConn, barConn = makeKVStoreConnPair()
  247. var fooPrvKey = ed25519.GenPrivKey()
  248. var barPrvKey = secp256k1.GenPrivKey()
  249. go func() {
  250. _, err := MakeSecretConnection(barConn, barPrvKey)
  251. assert.NoError(t, err)
  252. }()
  253. assert.NotPanics(t, func() {
  254. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  255. if assert.Error(t, err) {
  256. assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKey", err.Error())
  257. }
  258. })
  259. }
  260. func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
  261. defer wg.Done()
  262. for i := 0; i < n; i++ {
  263. _, err := conn.Write([]byte(txt))
  264. if err != nil {
  265. t.Errorf("failed to write to fooSecConn: %v", err)
  266. return
  267. }
  268. }
  269. }
  270. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  271. readBuffer := make([]byte, dataMaxSize)
  272. for i := 0; i < n; i++ {
  273. _, err := conn.Read(readBuffer)
  274. assert.NoError(t, err)
  275. }
  276. wg.Done()
  277. }
  278. // Creates the data for a test vector file.
  279. // The file format is:
  280. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  281. func createGoldenTestVectors(t *testing.T) string {
  282. data := ""
  283. for i := 0; i < 32; i++ {
  284. randSecretVector := tmrand.Bytes(32)
  285. randSecret := new([32]byte)
  286. copy((*randSecret)[:], randSecretVector)
  287. data += hex.EncodeToString((*randSecret)[:]) + ","
  288. locIsLeast := tmrand.Bool()
  289. data += strconv.FormatBool(locIsLeast) + ","
  290. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  291. data += hex.EncodeToString((*recvSecret)[:]) + ","
  292. data += hex.EncodeToString((*sendSecret)[:]) + ","
  293. }
  294. return data
  295. }
  296. // Each returned ReadWriteCloser is akin to a net.Connection
  297. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  298. barReader, fooWriter := io.Pipe()
  299. fooReader, barWriter := io.Pipe()
  300. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  301. }
  302. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  303. var (
  304. fooConn, barConn = makeKVStoreConnPair()
  305. fooPrvKey = ed25519.GenPrivKey()
  306. fooPubKey = fooPrvKey.PubKey()
  307. barPrvKey = ed25519.GenPrivKey()
  308. barPubKey = barPrvKey.PubKey()
  309. )
  310. // Make connections from both sides in parallel.
  311. var trs, ok = async.Parallel(
  312. func(_ int) (val interface{}, abort bool, err error) {
  313. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  314. if err != nil {
  315. tb.Errorf("failed to establish SecretConnection for foo: %v", err)
  316. return nil, true, err
  317. }
  318. remotePubBytes := fooSecConn.RemotePubKey()
  319. if !remotePubBytes.Equals(barPubKey) {
  320. err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  321. barPubKey, fooSecConn.RemotePubKey())
  322. tb.Error(err)
  323. return nil, true, err
  324. }
  325. return nil, false, nil
  326. },
  327. func(_ int) (val interface{}, abort bool, err error) {
  328. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  329. if barSecConn == nil {
  330. tb.Errorf("failed to establish SecretConnection for bar: %v", err)
  331. return nil, true, err
  332. }
  333. remotePubBytes := barSecConn.RemotePubKey()
  334. if !remotePubBytes.Equals(fooPubKey) {
  335. err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  336. fooPubKey, barSecConn.RemotePubKey())
  337. tb.Error(err)
  338. return nil, true, err
  339. }
  340. return nil, false, nil
  341. },
  342. )
  343. require.Nil(tb, trs.FirstError())
  344. require.True(tb, ok, "Unexpected task abortion")
  345. return fooSecConn, barSecConn
  346. }
  347. ///////////////////////////////////////////////////////////////////////////////
  348. // Benchmarks
  349. func BenchmarkWriteSecretConnection(b *testing.B) {
  350. b.StopTimer()
  351. b.ReportAllocs()
  352. fooSecConn, barSecConn := makeSecretConnPair(b)
  353. randomMsgSizes := []int{
  354. dataMaxSize / 10,
  355. dataMaxSize / 3,
  356. dataMaxSize / 2,
  357. dataMaxSize,
  358. dataMaxSize * 3 / 2,
  359. dataMaxSize * 2,
  360. dataMaxSize * 7 / 2,
  361. }
  362. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  363. for _, size := range randomMsgSizes {
  364. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  365. }
  366. // Consume reads from bar's reader
  367. go func() {
  368. readBuffer := make([]byte, dataMaxSize)
  369. for {
  370. _, err := barSecConn.Read(readBuffer)
  371. if err == io.EOF {
  372. return
  373. } else if err != nil {
  374. b.Errorf("failed to read from barSecConn: %v", err)
  375. return
  376. }
  377. }
  378. }()
  379. b.StartTimer()
  380. for i := 0; i < b.N; i++ {
  381. idx := tmrand.Intn(len(fooWriteBytes))
  382. _, err := fooSecConn.Write(fooWriteBytes[idx])
  383. if err != nil {
  384. b.Errorf("failed to write to fooSecConn: %v", err)
  385. return
  386. }
  387. }
  388. b.StopTimer()
  389. if err := fooSecConn.Close(); err != nil {
  390. b.Error(err)
  391. }
  392. //barSecConn.Close() race condition
  393. }
  394. func BenchmarkReadSecretConnection(b *testing.B) {
  395. b.StopTimer()
  396. b.ReportAllocs()
  397. fooSecConn, barSecConn := makeSecretConnPair(b)
  398. randomMsgSizes := []int{
  399. dataMaxSize / 10,
  400. dataMaxSize / 3,
  401. dataMaxSize / 2,
  402. dataMaxSize,
  403. dataMaxSize * 3 / 2,
  404. dataMaxSize * 2,
  405. dataMaxSize * 7 / 2,
  406. }
  407. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  408. for _, size := range randomMsgSizes {
  409. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  410. }
  411. go func() {
  412. for i := 0; i < b.N; i++ {
  413. idx := tmrand.Intn(len(fooWriteBytes))
  414. _, err := fooSecConn.Write(fooWriteBytes[idx])
  415. if err != nil {
  416. b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  417. return
  418. }
  419. }
  420. }()
  421. b.StartTimer()
  422. for i := 0; i < b.N; i++ {
  423. readBuffer := make([]byte, dataMaxSize)
  424. _, err := barSecConn.Read(readBuffer)
  425. if err == io.EOF {
  426. return
  427. } else if err != nil {
  428. b.Fatalf("Failed to read from barSecConn: %v", err)
  429. }
  430. }
  431. b.StopTimer()
  432. }