You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

476 lines
13 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "os"
  10. "path/filepath"
  11. "strconv"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "github.com/stretchr/testify/assert"
  16. "github.com/stretchr/testify/require"
  17. "github.com/tendermint/tendermint/crypto"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. "github.com/tendermint/tendermint/crypto/secp256k1"
  20. "github.com/tendermint/tendermint/libs/async"
  21. tmos "github.com/tendermint/tendermint/libs/os"
  22. tmrand "github.com/tendermint/tendermint/libs/rand"
  23. )
  // kvstoreConn is an in-memory, full-duplex connection assembled from two
// io.Pipe halves: Read is served by the embedded PipeReader and Write by the
// embedded PipeWriter. It is akin to a net.Conn.
type kvstoreConn struct {
	*io.PipeReader
	*io.PipeWriter
}

// Close closes both halves of the connection. The write half is closed with
// io.EOF, so the peer's reads return io.EOF. The read half is closed, so the
// peer's writes return io.ErrClosedPipe. If both closes fail, the write-half
// error is returned; otherwise the read-half error (or nil) is returned.
func (drw kvstoreConn) Close() (err error) {
	writeErr := drw.PipeWriter.CloseWithError(io.EOF)
	readErr := drw.PipeReader.Close()
	if writeErr != nil {
		// Previously this returned the (always nil) named result, silently
		// discarding the write-half error.
		return writeErr
	}
	return readErr
}

// makeKVStoreConnPair returns two kvstoreConns wired to each other: bytes
// written to fooConn are read from barConn and vice versa. Each returned
// ReadWriteCloser is akin to a net.Conn.
func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
	barReader, fooWriter := io.Pipe()
	fooReader, barWriter := io.Pipe()
	return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
}
  42. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  43. var fooConn, barConn = makeKVStoreConnPair()
  44. var fooPrvKey = ed25519.GenPrivKey()
  45. var fooPubKey = fooPrvKey.PubKey()
  46. var barPrvKey = ed25519.GenPrivKey()
  47. var barPubKey = barPrvKey.PubKey()
  48. // Make connections from both sides in parallel.
  49. var trs, ok = async.Parallel(
  50. func(_ int) (val interface{}, abort bool, err error) {
  51. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  52. if err != nil {
  53. tb.Errorf("failed to establish SecretConnection for foo: %v", err)
  54. return nil, true, err
  55. }
  56. remotePubBytes := fooSecConn.RemotePubKey()
  57. if !remotePubBytes.Equals(barPubKey) {
  58. err = fmt.Errorf("unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  59. barPubKey, fooSecConn.RemotePubKey())
  60. tb.Error(err)
  61. return nil, false, err
  62. }
  63. return nil, false, nil
  64. },
  65. func(_ int) (val interface{}, abort bool, err error) {
  66. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  67. if barSecConn == nil {
  68. tb.Errorf("failed to establish SecretConnection for bar: %v", err)
  69. return nil, true, err
  70. }
  71. remotePubBytes := barSecConn.RemotePubKey()
  72. if !remotePubBytes.Equals(fooPubKey) {
  73. err = fmt.Errorf("unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  74. fooPubKey, barSecConn.RemotePubKey())
  75. tb.Error(err)
  76. return nil, false, nil
  77. }
  78. return nil, false, nil
  79. },
  80. )
  81. require.Nil(tb, trs.FirstError())
  82. require.True(tb, ok, "Unexpected task abortion")
  83. return fooSecConn, barSecConn
  84. }
  85. func TestSecretConnectionHandshake(t *testing.T) {
  86. fooSecConn, barSecConn := makeSecretConnPair(t)
  87. if err := fooSecConn.Close(); err != nil {
  88. t.Error(err)
  89. }
  90. if err := barSecConn.Close(); err != nil {
  91. t.Error(err)
  92. }
  93. }
  94. func TestConcurrentWrite(t *testing.T) {
  95. fooSecConn, barSecConn := makeSecretConnPair(t)
  96. fooWriteText := tmrand.Str(dataMaxSize)
  97. // write from two routines.
  98. // should be safe from race according to net.Conn:
  99. // https://golang.org/pkg/net/#Conn
  100. n := 100
  101. wg := new(sync.WaitGroup)
  102. wg.Add(3)
  103. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  104. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  105. // Consume reads from bar's reader
  106. readLots(t, wg, barSecConn, n*2)
  107. wg.Wait()
  108. if err := fooSecConn.Close(); err != nil {
  109. t.Error(err)
  110. }
  111. }
  112. func TestConcurrentRead(t *testing.T) {
  113. fooSecConn, barSecConn := makeSecretConnPair(t)
  114. fooWriteText := tmrand.Str(dataMaxSize)
  115. n := 100
  116. // read from two routines.
  117. // should be safe from race according to net.Conn:
  118. // https://golang.org/pkg/net/#Conn
  119. wg := new(sync.WaitGroup)
  120. wg.Add(3)
  121. go readLots(t, wg, fooSecConn, n/2)
  122. go readLots(t, wg, fooSecConn, n/2)
  123. // write to bar
  124. writeLots(t, wg, barSecConn, fooWriteText, n)
  125. wg.Wait()
  126. if err := fooSecConn.Close(); err != nil {
  127. t.Error(err)
  128. }
  129. }
  // writeLots writes txt to conn n times, then marks wg done (via defer, so
// it happens on every return path). It stops at the first write error and
// reports it through t. conn may be either end of a connection pair, so the
// message identifies the failing write by index rather than by peer name.
func writeLots(t *testing.T, wg *sync.WaitGroup, conn io.Writer, txt string, n int) {
	defer wg.Done()
	// Convert once; io.Writer implementations must not retain or modify p.
	msg := []byte(txt)
	for i := 0; i < n; i++ {
		if _, err := conn.Write(msg); err != nil {
			t.Errorf("failed to write to conn (write %d of %d): %v", i+1, n, err)
			return
		}
	}
}
  140. func readLots(t *testing.T, wg *sync.WaitGroup, conn io.Reader, n int) {
  141. readBuffer := make([]byte, dataMaxSize)
  142. for i := 0; i < n; i++ {
  143. _, err := conn.Read(readBuffer)
  144. assert.NoError(t, err)
  145. }
  146. wg.Done()
  147. }
  148. func TestSecretConnectionReadWrite(t *testing.T) {
  149. fooConn, barConn := makeKVStoreConnPair()
  150. fooWrites, barWrites := []string{}, []string{}
  151. fooReads, barReads := []string{}, []string{}
  152. // Pre-generate the things to write (for foo & bar)
  153. for i := 0; i < 100; i++ {
  154. fooWrites = append(fooWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
  155. barWrites = append(barWrites, tmrand.Str((tmrand.Int()%(dataMaxSize*5))+1))
  156. }
  157. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  158. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) async.Task {
  159. return func(_ int) (interface{}, bool, error) {
  160. // Initiate cryptographic private key and secret connection trhough nodeConn.
  161. nodePrvKey := ed25519.GenPrivKey()
  162. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  163. if err != nil {
  164. t.Errorf("failed to establish SecretConnection for node: %v", err)
  165. return nil, true, err
  166. }
  167. // In parallel, handle some reads and writes.
  168. var trs, ok = async.Parallel(
  169. func(_ int) (interface{}, bool, error) {
  170. // Node writes:
  171. for _, nodeWrite := range nodeWrites {
  172. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  173. if err != nil {
  174. t.Errorf("failed to write to nodeSecretConn: %v", err)
  175. return nil, true, err
  176. }
  177. if n != len(nodeWrite) {
  178. err = fmt.Errorf("failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  179. t.Error(err)
  180. return nil, true, err
  181. }
  182. }
  183. if err := nodeConn.PipeWriter.Close(); err != nil {
  184. t.Error(err)
  185. return nil, true, err
  186. }
  187. return nil, false, nil
  188. },
  189. func(_ int) (interface{}, bool, error) {
  190. // Node reads:
  191. readBuffer := make([]byte, dataMaxSize)
  192. for {
  193. n, err := nodeSecretConn.Read(readBuffer)
  194. if err == io.EOF {
  195. if err := nodeConn.PipeReader.Close(); err != nil {
  196. t.Error(err)
  197. return nil, true, err
  198. }
  199. return nil, false, nil
  200. } else if err != nil {
  201. t.Errorf("failed to read from nodeSecretConn: %v", err)
  202. return nil, true, err
  203. }
  204. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  205. }
  206. },
  207. )
  208. assert.True(t, ok, "Unexpected task abortion")
  209. // If error:
  210. if trs.FirstError() != nil {
  211. return nil, true, trs.FirstError()
  212. }
  213. // Otherwise:
  214. return nil, false, nil
  215. }
  216. }
  217. // Run foo & bar in parallel
  218. var trs, ok = async.Parallel(
  219. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  220. genNodeRunner("bar", barConn, barWrites, &barReads),
  221. )
  222. require.Nil(t, trs.FirstError())
  223. require.True(t, ok, "unexpected task abortion")
  224. // A helper to ensure that the writes and reads match.
  225. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  226. compareWritesReads := func(writes []string, reads []string) {
  227. for {
  228. // Pop next write & corresponding reads
  229. var read, write string = "", writes[0]
  230. var readCount = 0
  231. for _, readChunk := range reads {
  232. read += readChunk
  233. readCount++
  234. if len(write) <= len(read) {
  235. break
  236. }
  237. if len(write) <= dataMaxSize {
  238. break // atomicity of small writes
  239. }
  240. }
  241. // Compare
  242. if write != read {
  243. t.Errorf("expected to read %X, got %X", write, read)
  244. }
  245. // Iterate
  246. writes = writes[1:]
  247. reads = reads[readCount:]
  248. if len(writes) == 0 {
  249. break
  250. }
  251. }
  252. }
  253. compareWritesReads(fooWrites, barReads)
  254. compareWritesReads(barWrites, fooReads)
  255. }
  // update is enabled by running `go test -update` from within this module.
// When set, TestDeriveSecretsAndChallengeGolden regenerates its golden test
// vector file before checking it.
var (
	update = flag.Bool("update", false, "update .golden files")
)
  259. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  260. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  261. if *update {
  262. t.Logf("Updating golden test vector file %s", goldenFilepath)
  263. data := createGoldenTestVectors(t)
  264. tmos.WriteFile(goldenFilepath, []byte(data), 0644)
  265. }
  266. f, err := os.Open(goldenFilepath)
  267. if err != nil {
  268. log.Fatal(err)
  269. }
  270. defer f.Close()
  271. scanner := bufio.NewScanner(f)
  272. for scanner.Scan() {
  273. line := scanner.Text()
  274. params := strings.Split(line, ",")
  275. randSecretVector, err := hex.DecodeString(params[0])
  276. require.Nil(t, err)
  277. randSecret := new([32]byte)
  278. copy((*randSecret)[:], randSecretVector)
  279. locIsLeast, err := strconv.ParseBool(params[1])
  280. require.Nil(t, err)
  281. expectedRecvSecret, err := hex.DecodeString(params[2])
  282. require.Nil(t, err)
  283. expectedSendSecret, err := hex.DecodeString(params[3])
  284. require.Nil(t, err)
  285. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  286. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  287. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  288. }
  289. }
  290. type privKeyWithNilPubKey struct {
  291. orig crypto.PrivKey
  292. }
  293. func (pk privKeyWithNilPubKey) Bytes() []byte { return pk.orig.Bytes() }
  294. func (pk privKeyWithNilPubKey) Sign(msg []byte) ([]byte, error) { return pk.orig.Sign(msg) }
  295. func (pk privKeyWithNilPubKey) PubKey() crypto.PubKey { return nil }
  296. func (pk privKeyWithNilPubKey) Equals(pk2 crypto.PrivKey) bool { return pk.orig.Equals(pk2) }
  297. func TestNilPubkey(t *testing.T) {
  298. var fooConn, barConn = makeKVStoreConnPair()
  299. var fooPrvKey = ed25519.GenPrivKey()
  300. var barPrvKey = privKeyWithNilPubKey{ed25519.GenPrivKey()}
  301. go func() {
  302. _, err := MakeSecretConnection(barConn, barPrvKey)
  303. assert.NoError(t, err)
  304. }()
  305. assert.NotPanics(t, func() {
  306. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  307. if assert.Error(t, err) {
  308. assert.Equal(t, "expected ed25519 pubkey, got <nil>", err.Error())
  309. }
  310. })
  311. }
  312. func TestNonEd25519Pubkey(t *testing.T) {
  313. var fooConn, barConn = makeKVStoreConnPair()
  314. var fooPrvKey = ed25519.GenPrivKey()
  315. var barPrvKey = secp256k1.GenPrivKey()
  316. go func() {
  317. _, err := MakeSecretConnection(barConn, barPrvKey)
  318. assert.NoError(t, err)
  319. }()
  320. assert.NotPanics(t, func() {
  321. _, err := MakeSecretConnection(fooConn, fooPrvKey)
  322. if assert.Error(t, err) {
  323. assert.Equal(t, "expected ed25519 pubkey, got secp256k1.PubKeySecp256k1", err.Error())
  324. }
  325. })
  326. }
  327. // Creates the data for a test vector file.
  328. // The file format is:
  329. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  330. func createGoldenTestVectors(t *testing.T) string {
  331. data := ""
  332. for i := 0; i < 32; i++ {
  333. randSecretVector := tmrand.Bytes(32)
  334. randSecret := new([32]byte)
  335. copy((*randSecret)[:], randSecretVector)
  336. data += hex.EncodeToString((*randSecret)[:]) + ","
  337. locIsLeast := tmrand.Bool()
  338. data += strconv.FormatBool(locIsLeast) + ","
  339. recvSecret, sendSecret := deriveSecrets(randSecret, locIsLeast)
  340. data += hex.EncodeToString((*recvSecret)[:]) + ","
  341. data += hex.EncodeToString((*sendSecret)[:]) + ","
  342. }
  343. return data
  344. }
  345. func BenchmarkWriteSecretConnection(b *testing.B) {
  346. b.StopTimer()
  347. b.ReportAllocs()
  348. fooSecConn, barSecConn := makeSecretConnPair(b)
  349. randomMsgSizes := []int{
  350. dataMaxSize / 10,
  351. dataMaxSize / 3,
  352. dataMaxSize / 2,
  353. dataMaxSize,
  354. dataMaxSize * 3 / 2,
  355. dataMaxSize * 2,
  356. dataMaxSize * 7 / 2,
  357. }
  358. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  359. for _, size := range randomMsgSizes {
  360. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  361. }
  362. // Consume reads from bar's reader
  363. go func() {
  364. readBuffer := make([]byte, dataMaxSize)
  365. for {
  366. _, err := barSecConn.Read(readBuffer)
  367. if err == io.EOF {
  368. return
  369. } else if err != nil {
  370. b.Errorf("failed to read from barSecConn: %v", err)
  371. return
  372. }
  373. }
  374. }()
  375. b.StartTimer()
  376. for i := 0; i < b.N; i++ {
  377. idx := tmrand.Intn(len(fooWriteBytes))
  378. _, err := fooSecConn.Write(fooWriteBytes[idx])
  379. if err != nil {
  380. b.Errorf("failed to write to fooSecConn: %v", err)
  381. return
  382. }
  383. }
  384. b.StopTimer()
  385. if err := fooSecConn.Close(); err != nil {
  386. b.Error(err)
  387. }
  388. //barSecConn.Close() race condition
  389. }
  390. func BenchmarkReadSecretConnection(b *testing.B) {
  391. b.StopTimer()
  392. b.ReportAllocs()
  393. fooSecConn, barSecConn := makeSecretConnPair(b)
  394. randomMsgSizes := []int{
  395. dataMaxSize / 10,
  396. dataMaxSize / 3,
  397. dataMaxSize / 2,
  398. dataMaxSize,
  399. dataMaxSize * 3 / 2,
  400. dataMaxSize * 2,
  401. dataMaxSize * 7 / 2,
  402. }
  403. fooWriteBytes := make([][]byte, 0, len(randomMsgSizes))
  404. for _, size := range randomMsgSizes {
  405. fooWriteBytes = append(fooWriteBytes, tmrand.Bytes(size))
  406. }
  407. go func() {
  408. for i := 0; i < b.N; i++ {
  409. idx := tmrand.Intn(len(fooWriteBytes))
  410. _, err := fooSecConn.Write(fooWriteBytes[idx])
  411. if err != nil {
  412. b.Errorf("failed to write to fooSecConn: %v, %v,%v", err, i, b.N)
  413. return
  414. }
  415. }
  416. }()
  417. b.StartTimer()
  418. for i := 0; i < b.N; i++ {
  419. readBuffer := make([]byte, dataMaxSize)
  420. _, err := barSecConn.Read(readBuffer)
  421. if err == io.EOF {
  422. return
  423. } else if err != nil {
  424. b.Fatalf("Failed to read from barSecConn: %v", err)
  425. }
  426. }
  427. b.StopTimer()
  428. }