You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

471 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "io/ioutil"
  9. "log"
  10. mrand "math/rand"
  11. "os"
  12. "path/filepath"
  13. "strconv"
  14. "strings"
  15. "sync"
  16. "testing"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/stretchr/testify/require"
  19. "github.com/tendermint/tendermint/crypto"
  20. "github.com/tendermint/tendermint/crypto/ed25519"
  21. "github.com/tendermint/tendermint/crypto/sr25519"
  22. "github.com/tendermint/tendermint/libs/async"
  23. tmrand "github.com/tendermint/tendermint/libs/rand"
  24. )
  25. // Run go test -update from within this module
  26. // to update the golden test vector file
  27. var update = flag.Bool("update", false, "update .golden files")
  28. type kvstoreConn struct {
  29. *io.PipeReader
  30. *io.PipeWriter
  31. }
  32. func (drw kvstoreConn) Close() (err error) {
  33. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  34. err1 := drw.PipeReader.Close()
  35. if err2 != nil {
  36. return err
  37. }
  38. return err1
  39. }
  40. type privKeyWithNilPubKey struct {
  41. orig crypto.PrivKey
  42. }
  43. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  44. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  45. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  46. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  47. func (pk privKeyWithNilPubKey) Type() string { return "privKeyWithNilPubKey" }
  48. func TestSecretConnectionHandshake(t *testing.T) {
  49. fooSecConn, barSecConn := makeSecretConnPair(t)
  50. if err := fooSecConn.Close(); err != nil {
  51. t.Error(err)
  52. }
  53. if err := barSecConn.Close(); err != nil {
  54. t.Error(err)
  55. }
  56. }
  57. func TestConcurrentWrite(t *testing.T) {
  58. fooSecConn, barSecConn := makeSecretConnPair(t)
  59. fooWriteText := tmrand.Str(dataMaxSize)
  60. // write from two routines.
  61. // should be safe from race according to net.Conn:
  62. // https://golang.org/pkg/net/#Conn
  63. n := 100
  64. wg := new(sync.WaitGroup)
  65. wg.Add(3)
  66. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  67. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  68. // Consume reads from bar's reader
  69. readLots(t, wg, barSecConn, n*2)
  70. wg.Wait()
  71. if err := fooSecConn.Close(); err != nil {
  72. t.Error(err)
  73. }
  74. }
  75. func TestConcurrentRead(t *testing.T) {
  76. fooSecConn, barSecConn := makeSecretConnPair(t)
  77. fooWriteText := tmrand.Str(dataMaxSize)
  78. n := 100
  79. // read from two routines.
  80. // should be safe from race according to net.Conn:
  81. // https://golang.org/pkg/net/#Conn
  82. wg := new(sync.WaitGroup)
  83. wg.Add(3)
  84. go readLots(t, wg, fooSecConn, n/2)
  85. go readLots(t, wg, fooSecConn, n/2)
  86. // write to bar
  87. writeLots(t, wg, barSecConn, fooWriteText, n)
  88. wg.Wait()
  89. if err := fooSecConn.Close(); err != nil {
  90. t.Error(err)
  91. }
  92. }
  93. func TestSecretConnectionReadWrite(t *testing.T) {
  94. fooConn, barConn := makeKVStoreConnPair()
  95. fooWrites, barWrites := []string{}, []string{}
  96. fooReads, barReads := []string{}, []string{}
  97. // Pre-generate the things to write (for foo & bar)
  98. for i := 0; i < 100; i++ {
  99. fooWrites = append(fooWrites, tmrand.Str((mrand.Int()%(dataMaxSize*5))+1))
  100. barWrites = append(barWrites, tmrand.Str((mrand.Int()%(dataMaxSize*5))+1))
  101. }
  102. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  103. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
  104. return func(_ int) (interface{}, bool, error) {
  105. // Initiate cryptographic private key and secret connection trhough nodeConn.
  106. nodePrvKey := ed25519.GenPrivKey()
  107. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  108. if err != nil {
  109. t.Errorf("failed to establish SecretConnection for node: %v", err)
  110. return nil, true, err
  111. }
  112. // In parallel, handle some reads and writes.
  113. var trs, ok = async.Parallel(
  114. func(_ int) (interface{}, bool, error) {
  115. // Node writes:
  116. for _, nodeWrite := range nodeWrites {
  117. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  118. if err != nil {
  119. t.Errorf("failed to write to nodeSecretConn: %v", err)
  120. return nil, true, err
  121. }
  122. if n != len(nodeWrite) {
  123. err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  124. t.Error(err)
  125. return nil, true, err
  126. }
  127. }
  128. if err := nodeConn.PipeWriter.Close(); err != nil {
  129. t.Error(err)
  130. return nil, true, err
  131. }
  132. return nil, false, nil
  133. },
  134. func(_ int) (interface{}, bool, error) {
  135. // Node reads:
  136. readBuffer := make([]byte, dataMaxSize)
  137. for {
  138. n, err := nodeSecretConn.Read(readBuffer)
  139. if err == io.EOF {
  140. if err := nodeConn.PipeReader.Close(); err != nil {
  141. t.Error(err)
  142. return nil, true, err
  143. }
  144. return nil, false, nil
  145. } else if err != nil {
  146. t.Errorf("failed to read from nodeSecretConn: %v", err)
  147. return nil, true, err
  148. }
  149. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  150. }
  151. },
  152. )
  153. assert.True(t, ok, "Unexpected task abortion")
  154. // If error:
  155. if trs.FirstError() != nil {
  156. return nil, true, trs.FirstError()
  157. }
  158. // Otherwise:
  159. return nil, false, nil
  160. }
  161. }
  162. // Run foo & bar in parallel
  163. var trs, ok = async.Parallel(
  164. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  165. genNodeRunner("bar", barConn, barWrites, &barReads),
  166. )
  167. require.Nil(t, trs.FirstError())
  168. require.True(t, ok, "unexpected task abortion")
  169. // A helper to ensure that the writes and reads match.
  170. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  171. compareWritesReads := func(writes []string, reads []string) {
  172. for {
  173. // Pop next write & corresponding reads
  174. var read, write string = "", writes[0]
  175. var readCount = 0
  176. for _, readChunk := range reads {
  177. read += readChunk
  178. readCount++
  179. if len(write) <= len(read) {
  180. break
  181. }
  182. if len(write) <= dataMaxSize {
  183. break // atomicity of small writes
  184. }
  185. }
  186. // Compare
  187. if write != read {
  188. t.Errorf("expected to read %X, got %X", write, read)
  189. }
  190. // Iterate
  191. writes = writes[1:]
  192. reads = reads[readCount:]
  193. if len(writes) == 0 {
  194. break
  195. }
  196. }
  197. }
  198. compareWritesReads(fooWrites, barReads)
  199. compareWritesReads(barWrites, fooReads)
  200. }
  201. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  202. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  203. if *update {
  204. t.Logf("Updating golden test vector file %s", goldenFilepath)
  205. data := createGoldenTestVectors(t)
  206. require.NoError(t, ioutil.WriteFile(goldenFilepath, []byte(data), 0644))
  207. }
  208. f, err := os.Open(goldenFilepath)
  209. if err != nil {
  210. log.Fatal(err)
  211. }
  212. t.Cleanup(closeAll(t, f))
  213. scanner := bufio.NewScanner(f)
  214. for scanner.Scan() {
  215. line := scanner.Text()
  216. params := strings.Split(line, ",")
  217. randSecretVector, err := hex.DecodeString(params[0])
  218. require.Nil(t, err)
  219. randSecret := new([32]byte)
  220. copy((*randSecret)[:], randSecretVector)
  221. locIsLeast, err := strconv.ParseBool(params[1])
  222. require.Nil(t, err)
  223. expectedRecvSecret, err := hex.DecodeString(params[2])
  224. require.Nil(t, err)
  225. expectedSendSecret, err := hex.DecodeString(params[3])
  226. require.Nil(t, err)
  227. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  228. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  229. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  230. }
  231. }
  232. func TestNilPubkey(t *testing.T) {
  233. var fooConn, barConn = makeKVStoreConnPair()
  234. t.Cleanup(closeAll(t, fooConn, barConn))
  235. var fooPrvKey = ed25519.GenPrivKey()
  236. var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
  237. go MakeSecretConnection(fooConn, fooPrvKey) //nolint:errcheck // ignore for tests
  238. _, err := MakeSecretConnection(barConn, barPrvKey)
  239. require.Error(t, err)
  240. assert.Equal(t, "toproto: key type <nil> is not supported", err.Error())
  241. }
  242. func TestNonEd25519Pubkey(t *testing.T) {
  243. var fooConn, barConn = makeKVStoreConnPair()
  244. t.Cleanup(closeAll(t, fooConn, barConn))
  245. var fooPrvKey = ed25519.GenPrivKey()
  246. var barPrvKey = sr25519.GenPrivKey()
  247. go MakeSecretConnection(barConn, barPrvKey) //nolint:errcheck // ignore for tests
  248. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  249. require.Error(t, err)
  250. }
  251. func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
  252. defer wg.Done()
  253. for i := 0; i < n; i++ {
  254. _, err := conn.Write([]byte(txt))
  255. if err != nil {
  256. t.Errorf("failed to write to fooSecConn: %v", err)
  257. return
  258. }
  259. }
  260. }
  261. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  262. readBuffer := make([]byte, dataMaxSize)
  263. for i := 0; i < n; i++ {
  264. _, err := conn.Read(readBuffer)
  265. assert.NoError(t, err)
  266. }
  267. wg.Done()
  268. }
  269. // Creates the data for a test vector file.
  270. // The file format is:
  271. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  272. func createGoldenTestVectors(t *testing.T) string {
  273. data := ""
  274. for i := 0; i < 32; i++ {
  275. randSecretVector := tmrand.Bytes(32)
  276. randSecret := new([32]byte)
  277. copy((*randSecret)[:], randSecretVector)
  278. data += hex.EncodeToString((*randSecret)[:]) + ","
  279. locIsLeast := mrand.Int63()%2 == 0
  280. data += strconv.FormatBool(locIsLeast) + ","
  281. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  282. data += hex.EncodeToString((*recvSecret)[:]) + ","
  283. data += hex.EncodeToString((*sendSecret)[:]) + ","
  284. }
  285. return data
  286. }
  287. // Each returned ReadWriteCloser is akin to a net.Connection
  288. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  289. barReader, fooWriter := io.Pipe()
  290. fooReader, barWriter := io.Pipe()
  291. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  292. }
  293. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  294. var (
  295. fooConn, barConn = makeKVStoreConnPair()
  296. fooPrvKey = ed25519.GenPrivKey()
  297. fooPubKey = fooPrvKey.PubKey()
  298. barPrvKey = ed25519.GenPrivKey()
  299. barPubKey = barPrvKey.PubKey()
  300. )
  301. // Make connections from both sides in parallel.
  302. var trs, ok = async.Parallel(
  303. func(_ int) (val interface{}, abort bool, err error) {
  304. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  305. if err != nil {
  306. tb.Errorf("failed to establish SecretConnection for foo: %v", err)
  307. return nil, true, err
  308. }
  309. remotePubBytes := fooSecConn.RemotePubKey()
  310. if !remotePubBytes.Equals(barPubKey) {
  311. err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  312. barPubKey, fooSecConn.RemotePubKey())
  313. tb.Error(err)
  314. return nil, true, err
  315. }
  316. return nil, false, nil
  317. },
  318. func(_ int) (val interface{}, abort bool, err error) {
  319. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  320. if barSecConn == nil {
  321. tb.Errorf("failed to establish SecretConnection for bar: %v", err)
  322. return nil, true, err
  323. }
  324. remotePubBytes := barSecConn.RemotePubKey()
  325. if !remotePubBytes.Equals(fooPubKey) {
  326. err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  327. fooPubKey, barSecConn.RemotePubKey())
  328. tb.Error(err)
  329. return nil, true, err
  330. }
  331. return nil, false, nil
  332. },
  333. )
  334. require.Nil(tb, trs.FirstError())
  335. require.True(tb, ok, "Unexpected task abortion")
  336. return fooSecConn, barSecConn
  337. }
  338. // Benchmarks
  339. func BenchmarkWriteSecretConnection(b *testing.B) {
  340. b.StopTimer()
  341. b.ReportAllocs()
  342. fooSecConn, barSecConn := makeSecretConnPair(b)
  343. randomMsgSizes := []int{
  344. dataMaxSize / 10,
  345. dataMaxSize / 3,
  346. dataMaxSize / 2,
  347. dataMaxSize,
  348. dataMaxSize * 3 / 2,
  349. dataMaxSize * 2,
  350. dataMaxSize * 7 / 2,
  351. }
  352. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  353. for _, size := range randomMsgSizes {
  354. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  355. }
  356. // Consume reads from bar's reader
  357. go func() {
  358. readBuffer := make([]byte, dataMaxSize)
  359. for {
  360. _, err := barSecConn.Read(readBuffer)
  361. if err == io.EOF {
  362. return
  363. } else if err != nil {
  364. b.Errorf("failed to read from barSecConn: %v", err)
  365. return
  366. }
  367. }
  368. }()
  369. b.StartTimer()
  370. for i := 0; i < b.N; i++ {
  371. idx := mrand.Intn(len(fooWriteBytes))
  372. _, err := fooSecConn.Write(fooWriteBytes[idx])
  373. if err != nil {
  374. b.Errorf("failed to write to fooSecConn: %v", err)
  375. return
  376. }
  377. }
  378. b.StopTimer()
  379. if err := fooSecConn.Close(); err != nil {
  380. b.Error(err)
  381. }
  382. // barSecConn.Close() race condition
  383. }
  384. func BenchmarkReadSecretConnection(b *testing.B) {
  385. b.StopTimer()
  386. b.ReportAllocs()
  387. fooSecConn, barSecConn := makeSecretConnPair(b)
  388. randomMsgSizes := []int{
  389. dataMaxSize / 10,
  390. dataMaxSize / 3,
  391. dataMaxSize / 2,
  392. dataMaxSize,
  393. dataMaxSize * 3 / 2,
  394. dataMaxSize * 2,
  395. dataMaxSize * 7 / 2,
  396. }
  397. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  398. for _, size := range randomMsgSizes {
  399. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  400. }
  401. go func() {
  402. for i := 0; i < b.N; i++ {
  403. idx := mrand.Intn(len(fooWriteBytes))
  404. _, err := fooSecConn.Write(fooWriteBytes[idx])
  405. if err != nil {
  406. b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  407. return
  408. }
  409. }
  410. }()
  411. b.StartTimer()
  412. for i := 0; i < b.N; i++ {
  413. readBuffer := make([]byte, dataMaxSize)
  414. _, err := barSecConn.Read(readBuffer)
  415. if err == io.EOF {
  416. return
  417. } else if err != nil {
  418. b.Fatalf("Failed to read from barSecConn: %v", err)
  419. }
  420. }
  421. b.StopTimer()
  422. }