You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

474 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. "github.com/tendermint/tendermint/crypto/sr25519"
  20. "github.com/tendermint/tendermint/libs/async"
  21. tmos "github.com/tendermint/tendermint/libs/os"
  22. tmrand "github.com/tendermint/tendermint/libs/rand"
  23. )
  // update is enabled by running `go test -update` from within this module.
// When set, TestDeriveSecretsAndChallengeGolden regenerates its golden test
// vector file before checking against it.
var (
	update = flag.Bool("update", false, "update .golden files")
)
  27. type kvstoreConn struct {
  28. *io.PipeReader
  29. *io.PipeWriter
  30. }
  31. func (drw kvstoreConn) Close() (err error) {
  32. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  33. err1 := drw.PipeReader.Close()
  34. if err2 != nil {
  35. return err
  36. }
  37. return err1
  38. }
  39. type privKeyWithNilPubKey struct {
  40. orig crypto.PrivKey
  41. }
  42. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  43. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  44. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  45. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  46. func (pk privKeyWithNilPubKey) Type() string { return "privKeyWithNilPubKey" }
  // TestSecretConnectionHandshake checks that two peers can complete the
// handshake and that both resulting connections (foo first, then bar) close
// without error.
func TestSecretConnectionHandshake(t *testing.T) {
	fooSecConn, barSecConn := makeSecretConnPair(t)
	for _, sc := range []*SecretConnection{fooSecConn, barSecConn} {
		if err := sc.Close(); err != nil {
			t.Error(err)
		}
	}
}
  56. func TestConcurrentWrite(t *testing.T) {
  57. fooSecConn, barSecConn := makeSecretConnPair(t)
  58. fooWriteText := tmrand.Str(dataMaxSize)
  59. // write from two routines.
  60. // should be safe from race according to net.Conn:
  61. // https://golang.org/pkg/net/#Conn
  62. n := 100
  63. wg := new(sync.WaitGroup)
  64. wg.Add(3)
  65. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  66. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  67. // Consume reads from bar's reader
  68. readLots(t, wg, barSecConn, n*2)
  69. wg.Wait()
  70. if err := fooSecConn.Close(); err != nil {
  71. t.Error(err)
  72. }
  73. }
  74. func TestConcurrentRead(t *testing.T) {
  75. fooSecConn, barSecConn := makeSecretConnPair(t)
  76. fooWriteText := tmrand.Str(dataMaxSize)
  77. n := 100
  78. // read from two routines.
  79. // should be safe from race according to net.Conn:
  80. // https://golang.org/pkg/net/#Conn
  81. wg := new(sync.WaitGroup)
  82. wg.Add(3)
  83. go readLots(t, wg, fooSecConn, n/2)
  84. go readLots(t, wg, fooSecConn, n/2)
  85. // write to bar
  86. writeLots(t, wg, barSecConn, fooWriteText, n)
  87. wg.Wait()
  88. if err := fooSecConn.Close(); err != nil {
  89. t.Error(err)
  90. }
  91. }
  // TestSecretConnectionReadWrite runs two peers ("foo" and "bar") over an
// in-memory kvstoreConn pair. Each peer handshakes, then concurrently writes
// 100 random strings of 1 to 5*dataMaxSize bytes and reads until io.EOF.
// Afterwards the reads on each side, taken in order, must reproduce the other
// side's writes. A write of at most dataMaxSize bytes must arrive in exactly
// one Read.
func TestSecretConnectionReadWrite(t *testing.T) {
	fooConn, barConn := makeKVStoreConnPair()
	fooWrites, barWrites := []string{}, []string{}
	fooReads, barReads := []string{}, []string{}

	// Pre-generate the things to write (for foo & bar).
	for i := 0; i < 100; i++ {
		fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
		barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
	}

	// genNodeRunner returns a task that drives one peer. The task establishes a
	// SecretConnection over nodeConn, writes nodeWrites, and in parallel
	// appends every chunk it reads to *nodeReads. id labels failure messages.
	genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
		return func(_ int) (interface{}, bool, error) {
			// Initiate cryptographic private key and secret connection through nodeConn.
			nodePrvKey := ed25519.GenPrivKey()
			nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
			if err != nil {
				t.Errorf("failed to establish SecretConnection for node %s: %v", id, err)
				return nil, true, err
			}

			// In parallel, handle some reads and writes.
			trs, ok := async.Parallel(
				func(_ int) (interface{}, bool, error) {
					// Node writes:
					for _, nodeWrite := range nodeWrites {
						n, err := nodeSecretConn.Write([]byte(nodeWrite))
						if err != nil {
							t.Errorf("node %s failed to write to nodeSecretConn: %v", id, err)
							return nil, true, err
						}
						if n != len(nodeWrite) {
							err = fmt.Errorf("node %s failed to write all bytes. Expected %v, wrote %v",
								id, len(nodeWrite), n)
							t.Error(err)
							return nil, true, err
						}
					}
					// Closing our write half ends the peer's stream. The peer's
					// read loop below expects this to surface as io.EOF.
					if err := nodeConn.PipeWriter.Close(); err != nil {
						t.Error(err)
						return nil, true, err
					}
					return nil, false, nil
				},
				func(_ int) (interface{}, bool, error) {
					// Node reads until the peer closes its write half.
					readBuffer := make([]byte, dataMaxSize)
					for {
						n, err := nodeSecretConn.Read(readBuffer)
						if err == io.EOF {
							if err := nodeConn.PipeReader.Close(); err != nil {
								t.Error(err)
								return nil, true, err
							}
							return nil, false, nil
						} else if err != nil {
							t.Errorf("node %s failed to read from nodeSecretConn: %v", id, err)
							return nil, true, err
						}
						*nodeReads = append(*nodeReads, string(readBuffer[:n]))
					}
				},
			)
			assert.True(t, ok, "Unexpected task abortion")
			if err := trs.FirstError(); err != nil {
				return nil, true, err
			}
			return nil, false, nil
		}
	}

	// Run foo & bar in parallel.
	trs, ok := async.Parallel(
		genNodeRunner("foo", fooConn, fooWrites, &fooReads),
		genNodeRunner("bar", barConn, barWrites, &barReads),
	)
	require.NoError(t, trs.FirstError())
	require.True(t, ok, "unexpected task abortion")

	// compareWritesReads checks that reads, taken in order, reproduce writes.
	// For each write it consumes read chunks until their combined length
	// covers the write. A write of at most dataMaxSize bytes is matched
	// against exactly one chunk (small writes must be read atomically).
	compareWritesReads := func(writes []string, reads []string) {
		for _, write := range writes {
			var read strings.Builder
			consumed := 0
			for consumed < len(reads) {
				read.WriteString(reads[consumed])
				consumed++
				if read.Len() >= len(write) || len(write) <= dataMaxSize {
					break
				}
			}
			if got := read.String(); got != write {
				t.Errorf("expected to read %X, got %X", write, got)
			}
			reads = reads[consumed:]
		}
	}
	compareWritesReads(fooWrites, barReads)
	compareWritesReads(barWrites, fooReads)
}
  // TestDeriveSecretsAndChallengeGolden checks deriveSecrets against the test
// vectors stored in testdata/<test name>.golden. When run with -update, it
// first regenerates that file with createGoldenTestVectors.
//
// Each non-blank line must hold at least four comma-separated fields:
// Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret).
// Any additional fields are ignored.
func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
	goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
	if *update {
		t.Logf("Updating golden test vector file %s", goldenFilepath)
		data := createGoldenTestVectors(t)
		err := tmos.WriteFile(goldenFilepath, []byte(data), 0644)
		require.NoError(t, err)
	}

	f, err := os.Open(goldenFilepath)
	if err != nil {
		// NOTE(review): log.Fatal exits the whole test binary without normal
		// test reporting or deferred cleanup. t.Fatal would be gentler.
		log.Fatal(err)
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for lineNo := 1; scanner.Scan(); lineNo++ {
		line := scanner.Text()
		if strings.TrimSpace(line) == "" {
			continue
		}
		params := strings.Split(line, ",")
		// Fail cleanly on a malformed line rather than panicking on params[i].
		if len(params) < 4 {
			t.Fatalf("%s:%d: expected at least 4 comma-separated fields, got %d",
				goldenFilepath, lineNo, len(params))
		}

		randSecretVector, err := hex.DecodeString(params[0])
		require.NoError(t, err)
		randSecret := new([32]byte)
		copy((*randSecret)[:], randSecretVector)

		locIsLeast, err := strconv.ParseBool(params[1])
		require.NoError(t, err)
		expectedRecvSecret, err := hex.DecodeString(params[2])
		require.NoError(t, err)
		expectedSendSecret, err := hex.DecodeString(params[3])
		require.NoError(t, err)

		recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
		require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
		require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
	}
	// Scan returns false on read errors as well as at EOF; surface the former.
	require.NoError(t, scanner.Err())
}
  // TestNilPubkey checks that MakeSecretConnection fails when the local private
// key reports a nil public key. The foo side runs in a goroutine only to act
// as the handshake peer; its result is ignored.
func TestNilPubkey(t *testing.T) {
	const wantErr = "toproto: key type <nil> is not supported"

	fooConn, barConn := makeKVStoreConnPair()
	defer func() {
		barConn.Close()
		fooConn.Close()
	}()

	fooPrvKey := ed25519.GenPrivKey()
	barPrvKey := privKeyWithNilPubKey{orig: ed25519.GenPrivKey()}

	go func() {
		_, _ = MakeSecretConnection(fooConn, fooPrvKey) // result ignored in tests
	}()

	_, err := MakeSecretConnection(barConn, barPrvKey)
	require.Error(t, err)
	assert.Equal(t, wantErr, err.Error())
}
  // TestNonEd25519Pubkey checks that MakeSecretConnection rejects an sr25519
// local key with an error mentioning "is not supported". The foo side runs in
// a goroutine only to act as the handshake peer; its result is ignored.
func TestNonEd25519Pubkey(t *testing.T) {
	const wantErrFragment = "is not supported"

	fooConn, barConn := makeKVStoreConnPair()
	defer func() {
		barConn.Close()
		fooConn.Close()
	}()

	fooPrvKey := ed25519.GenPrivKey()
	barPrvKey := sr25519.GenPrivKey()

	go func() {
		_, _ = MakeSecretConnection(fooConn, fooPrvKey) // result ignored in tests
	}()

	_, err := MakeSecretConnection(barConn, barPrvKey)
	require.Error(t, err)
	assert.Contains(t, err.Error(), wantErrFragment)
}
  // writeLots writes txt to conn n times, then marks wg done.
// It stops at the first write error and reports it with t.Errorf, which is
// safe to call from a goroutine other than the test's own.
func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
	defer wg.Done()
	// Convert once: io.Writer implementations must not modify or retain p,
	// so the same slice can be reused for every write.
	msg := []byte(txt)
	for i := 0; i < n; i++ {
		if _, err := conn.Write(msg); err != nil {
			// conn is not always fooSecConn (TestConcurrentRead writes to bar),
			// so keep the message generic.
			t.Errorf("failed to write to conn (write %d of %d): %v", i+1, n, err)
			return
		}
	}
}
  // readLots performs n reads from conn into a dataMaxSize-byte scratch buffer,
// asserting that each read succeeds, then marks wg done. Failed reads are
// recorded but do not stop the loop.
func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
	defer wg.Done()
	scratch := make([]byte, dataMaxSize)
	for remaining := n; remaining > 0; remaining-- {
		_, err := conn.Read(scratch)
		assert.NoError(t, err)
	}
}
  272. // Creates the data for a test vector file.
  273. // The file format is:
  274. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  275. func createGoldenTestVectors(t *testing.T) string {
  276. data := ""
  277. for i := 0; i < 32; i++ {
  278. randSecretVector := tmrand.Bytes(32)
  279. randSecret := new([32]byte)
  280. copy((*randSecret)[:], randSecretVector)
  281. data += hex.EncodeToString((*randSecret)[:]) + ","
  282. locIsLeast := tmrand.Bool()
  283. data += strconv.FormatBool(locIsLeast) + ","
  284. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  285. data += hex.EncodeToString((*recvSecret)[:]) + ","
  286. data += hex.EncodeToString((*sendSecret)[:]) + ","
  287. }
  288. return data
  289. }
  // makeKVStoreConnPair cross-wires two kvstoreConns with a pair of io.Pipes:
// whatever foo writes, bar reads, and vice versa. Each returned value acts as
// one end of a net.Conn-like connection.
func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
	fooToBarR, fooToBarW := io.Pipe()
	barToFooR, barToFooW := io.Pipe()
	fooConn = kvstoreConn{PipeReader: barToFooR, PipeWriter: fooToBarW}
	barConn = kvstoreConn{PipeReader: fooToBarR, PipeWriter: barToFooW}
	return fooConn, barConn
}
  // makeSecretConnPair runs the SecretConnection handshake between two fresh
// ed25519 identities ("foo" and "bar") over an in-memory kvstoreConn pair.
// Both handshakes run in parallel, and each side checks that the remote public
// key it learned is the other side's key. Any failure is reported through tb,
// and the require checks at the end stop the caller on failure.
func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
	var (
		fooConn, barConn = makeKVStoreConnPair()
		fooPrvKey        = ed25519.GenPrivKey()
		fooPubKey        = fooPrvKey.PubKey()
		barPrvKey        = ed25519.GenPrivKey()
		barPubKey        = barPrvKey.PubKey()
	)

	// Make connections from both sides in parallel.
	trs, ok := async.Parallel(
		func(_ int) (val interface{}, abort bool, err error) {
			fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
			if err != nil {
				tb.Errorf("failed to establish SecretConnection for foo: %v", err)
				return nil, true, err
			}
			remotePubKey := fooSecConn.RemotePubKey()
			if !remotePubKey.Equals(barPubKey) {
				err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
					barPubKey, remotePubKey)
				tb.Error(err)
				return nil, true, err
			}
			return nil, false, nil
		},
		func(_ int) (val interface{}, abort bool, err error) {
			barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
			// Check err, as the foo side does, rather than barSecConn == nil,
			// so a failed handshake is always reported with its error.
			if err != nil {
				tb.Errorf("failed to establish SecretConnection for bar: %v", err)
				return nil, true, err
			}
			remotePubKey := barSecConn.RemotePubKey()
			if !remotePubKey.Equals(fooPubKey) {
				err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v",
					fooPubKey, remotePubKey)
				tb.Error(err)
				return nil, true, err
			}
			return nil, false, nil
		},
	)

	require.NoError(tb, trs.FirstError())
	require.True(tb, ok, "Unexpected task abortion")
	return fooSecConn, barSecConn
}
  ///////////////////////////////////////////////////////////////////////////////
// Benchmarks

// BenchmarkWriteSecretConnection measures SecretConnection.Write on messages of
// mixed sizes, from dataMaxSize/10 to 3.5*dataMaxSize. A background goroutine
// drains bar's side until io.EOF or an error.
func BenchmarkWriteSecretConnection(b *testing.B) {
	b.StopTimer()
	b.ReportAllocs()
	fooSecConn, barSecConn := makeSecretConnPair(b)

	msgSizes := []int{
		dataMaxSize / 10,
		dataMaxSize / 3,
		dataMaxSize / 2,
		dataMaxSize,
		dataMaxSize * 3 / 2,
		dataMaxSize * 2,
		dataMaxSize * 7 / 2,
	}
	payloads := make([][]byte, len(msgSizes))
	for i, size := range msgSizes {
		payloads[i] = tmrand.Bytes(size)
	}

	// Consume reads from bar's reader.
	go func() {
		sink := make([]byte, dataMaxSize)
		for {
			_, err := barSecConn.Read(sink)
			switch {
			case err == io.EOF:
				return
			case err != nil:
				b.Errorf("failed to read from barSecConn: %v", err)
				return
			}
		}
	}()

	b.StartTimer()
	for i := 0; i < b.N; i++ {
		msg := payloads[tmrand.Intn(len(payloads))]
		if _, err := fooSecConn.Write(msg); err != nil {
			b.Errorf("failed to write to fooSecConn: %v", err)
			return
		}
	}
	b.StopTimer()

	if err := fooSecConn.Close(); err != nil {
		b.Error(err)
	}
	// barSecConn.Close() is deliberately not called here (race condition).
}
  // BenchmarkReadSecretConnection measures SecretConnection.Read on bar's side.
// A background goroutine writes b.N messages of mixed sizes (dataMaxSize/10 to
// 3.5*dataMaxSize) from foo. The timed loop performs b.N reads into a
// dataMaxSize-byte buffer.
func BenchmarkReadSecretConnection(b *testing.B) {
	b.StopTimer()
	b.ReportAllocs()
	fooSecConn, barSecConn := makeSecretConnPair(b)
	randomMsgSizes := []int{
		dataMaxSize / 10,
		dataMaxSize / 3,
		dataMaxSize / 2,
		dataMaxSize,
		dataMaxSize * 3 / 2,
		dataMaxSize * 2,
		dataMaxSize * 7 / 2,
	}
	fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
	for _, size := range randomMsgSizes {
		fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
	}

	// NOTE(review): messages larger than dataMaxSize need several reads, so
	// b.N reads may consume fewer than b.N messages. The writer goroutine may
	// then stay blocked on the synchronous pipe after the benchmark ends.
	go func() {
		for i := 0; i < b.N; i++ {
			idx := tmrand.Intn(len(fooWriteBytes))
			_, err := fooSecConn.Write(fooWriteBytes[idx])
			if err != nil {
				b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
				return
			}
		}
	}()

	// Allocate the read buffer once, outside the timed loop. Otherwise
	// ReportAllocs counts the benchmark's own per-iteration buffer rather than
	// the cost of Read itself.
	readBuffer := make([]byte, dataMaxSize)
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		_, err := barSecConn.Read(readBuffer)
		if err == io.EOF {
			return
		} else if err != nil {
			b.Fatalf("Failed to read from barSecConn: %v", err)
		}
	}
	b.StopTimer()
}