You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

383 lines
10 KiB

  1. package conn
  2. import (
  3. "bufio"
  4. "encoding/hex"
  5. "flag"
  6. "fmt"
  7. "io"
  8. "log"
  9. "net"
  10. "os"
  11. "path/filepath"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "github.com/stretchr/testify/assert"
  17. "github.com/stretchr/testify/require"
  18. "github.com/tendermint/tendermint/crypto/ed25519"
  19. cmn "github.com/tendermint/tendermint/libs/common"
  20. )
  21. type kvstoreConn struct {
  22. *io.PipeReader
  23. *io.PipeWriter
  24. }
  25. func (drw kvstoreConn) Close() (err error) {
  26. err2 := drw.PipeWriter.CloseWithError(io.EOF)
  27. err1 := drw.PipeReader.Close()
  28. if err2 != nil {
  29. return err
  30. }
  31. return err1
  32. }
  33. // Each returned ReadWriteCloser is akin to a net.Connection
  34. func makeKVStoreConnPair() (fooConn, barConn kvstoreConn) {
  35. barReader, fooWriter := io.Pipe()
  36. fooReader, barWriter := io.Pipe()
  37. return kvstoreConn{fooReader, fooWriter}, kvstoreConn{barReader, barWriter}
  38. }
  39. func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) {
  40. var fooConn, barConn = makeKVStoreConnPair()
  41. var fooPrvKey = ed25519.GenPrivKey()
  42. var fooPubKey = fooPrvKey.PubKey()
  43. var barPrvKey = ed25519.GenPrivKey()
  44. var barPubKey = barPrvKey.PubKey()
  45. // Make connections from both sides in parallel.
  46. var trs, ok = cmn.Parallel(
  47. func(_ int) (val interface{}, err error, abort bool) {
  48. fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey)
  49. if err != nil {
  50. tb.Errorf("Failed to establish SecretConnection for foo: %v", err)
  51. return nil, err, true
  52. }
  53. remotePubBytes := fooSecConn.RemotePubKey()
  54. if !remotePubBytes.Equals(barPubKey) {
  55. err = fmt.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v",
  56. barPubKey, fooSecConn.RemotePubKey())
  57. tb.Error(err)
  58. return nil, err, false
  59. }
  60. return nil, nil, false
  61. },
  62. func(_ int) (val interface{}, err error, abort bool) {
  63. barSecConn, err = MakeSecretConnection(barConn, barPrvKey)
  64. if barSecConn == nil {
  65. tb.Errorf("Failed to establish SecretConnection for bar: %v", err)
  66. return nil, err, true
  67. }
  68. remotePubBytes := barSecConn.RemotePubKey()
  69. if !remotePubBytes.Equals(fooPubKey) {
  70. err = fmt.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v",
  71. fooPubKey, barSecConn.RemotePubKey())
  72. tb.Error(err)
  73. return nil, nil, false
  74. }
  75. return nil, nil, false
  76. },
  77. )
  78. require.Nil(tb, trs.FirstError())
  79. require.True(tb, ok, "Unexpected task abortion")
  80. return
  81. }
  82. func TestSecretConnectionHandshake(t *testing.T) {
  83. fooSecConn, barSecConn := makeSecretConnPair(t)
  84. if err := fooSecConn.Close(); err != nil {
  85. t.Error(err)
  86. }
  87. if err := barSecConn.Close(); err != nil {
  88. t.Error(err)
  89. }
  90. }
  91. func TestConcurrentWrite(t *testing.T) {
  92. fooSecConn, barSecConn := makeSecretConnPair(t)
  93. fooWriteText := cmn.RandStr(dataMaxSize)
  94. // write from two routines.
  95. // should be safe from race according to net.Conn:
  96. // https://golang.org/pkg/net/#Conn
  97. n := 100
  98. wg := new(sync.WaitGroup)
  99. wg.Add(3)
  100. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  101. go writeLots(t, wg, fooSecConn, fooWriteText, n)
  102. // Consume reads from bar's reader
  103. readLots(t, wg, barSecConn, n*2)
  104. wg.Wait()
  105. if err := fooSecConn.Close(); err != nil {
  106. t.Error(err)
  107. }
  108. }
  109. func TestConcurrentRead(t *testing.T) {
  110. fooSecConn, barSecConn := makeSecretConnPair(t)
  111. fooWriteText := cmn.RandStr(dataMaxSize)
  112. n := 100
  113. // read from two routines.
  114. // should be safe from race according to net.Conn:
  115. // https://golang.org/pkg/net/#Conn
  116. wg := new(sync.WaitGroup)
  117. wg.Add(3)
  118. go readLots(t, wg, fooSecConn, n/2)
  119. go readLots(t, wg, fooSecConn, n/2)
  120. // write to bar
  121. writeLots(t, wg, barSecConn, fooWriteText, n)
  122. wg.Wait()
  123. if err := fooSecConn.Close(); err != nil {
  124. t.Error(err)
  125. }
  126. }
  127. func writeLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, txt string, n int) {
  128. defer wg.Done()
  129. for i := 0; i < n; i++ {
  130. _, err := conn.Write([]byte(txt))
  131. if err != nil {
  132. t.Fatalf("Failed to write to fooSecConn: %v", err)
  133. }
  134. }
  135. }
  136. func readLots(t *testing.T, wg *sync.WaitGroup, conn net.Conn, n int) {
  137. readBuffer := make([]byte, dataMaxSize)
  138. for i := 0; i < n; i++ {
  139. _, err := conn.Read(readBuffer)
  140. assert.NoError(t, err)
  141. }
  142. wg.Done()
  143. }
  144. func TestSecretConnectionReadWrite(t *testing.T) {
  145. fooConn, barConn := makeKVStoreConnPair()
  146. fooWrites, barWrites := []string{}, []string{}
  147. fooReads, barReads := []string{}, []string{}
  148. // Pre-generate the things to write (for foo & bar)
  149. for i := 0; i < 100; i++ {
  150. fooWrites = append(fooWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  151. barWrites = append(barWrites, cmn.RandStr((cmn.RandInt()%(dataMaxSize*5))+1))
  152. }
  153. // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa
  154. genNodeRunner := func(id string, nodeConn kvstoreConn, nodeWrites []string, nodeReads *[]string) cmn.Task {
  155. return func(_ int) (interface{}, error, bool) {
  156. // Initiate cryptographic private key and secret connection trhough nodeConn.
  157. nodePrvKey := ed25519.GenPrivKey()
  158. nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey)
  159. if err != nil {
  160. t.Errorf("Failed to establish SecretConnection for node: %v", err)
  161. return nil, err, true
  162. }
  163. // In parallel, handle some reads and writes.
  164. var trs, ok = cmn.Parallel(
  165. func(_ int) (interface{}, error, bool) {
  166. // Node writes:
  167. for _, nodeWrite := range nodeWrites {
  168. n, err := nodeSecretConn.Write([]byte(nodeWrite))
  169. if err != nil {
  170. t.Errorf("Failed to write to nodeSecretConn: %v", err)
  171. return nil, err, true
  172. }
  173. if n != len(nodeWrite) {
  174. err = fmt.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n)
  175. t.Error(err)
  176. return nil, err, true
  177. }
  178. }
  179. if err := nodeConn.PipeWriter.Close(); err != nil {
  180. t.Error(err)
  181. return nil, err, true
  182. }
  183. return nil, nil, false
  184. },
  185. func(_ int) (interface{}, error, bool) {
  186. // Node reads:
  187. readBuffer := make([]byte, dataMaxSize)
  188. for {
  189. n, err := nodeSecretConn.Read(readBuffer)
  190. if err == io.EOF {
  191. return nil, nil, false
  192. } else if err != nil {
  193. t.Errorf("Failed to read from nodeSecretConn: %v", err)
  194. return nil, err, true
  195. }
  196. *nodeReads = append(*nodeReads, string(readBuffer[:n]))
  197. }
  198. if err := nodeConn.PipeReader.Close(); err != nil {
  199. t.Error(err)
  200. return nil, err, true
  201. }
  202. return nil, nil, false
  203. },
  204. )
  205. assert.True(t, ok, "Unexpected task abortion")
  206. // If error:
  207. if trs.FirstError() != nil {
  208. return nil, trs.FirstError(), true
  209. }
  210. // Otherwise:
  211. return nil, nil, false
  212. }
  213. }
  214. // Run foo & bar in parallel
  215. var trs, ok = cmn.Parallel(
  216. genNodeRunner("foo", fooConn, fooWrites, &fooReads),
  217. genNodeRunner("bar", barConn, barWrites, &barReads),
  218. )
  219. require.Nil(t, trs.FirstError())
  220. require.True(t, ok, "unexpected task abortion")
  221. // A helper to ensure that the writes and reads match.
  222. // Additionally, small writes (<= dataMaxSize) must be atomically read.
  223. compareWritesReads := func(writes []string, reads []string) {
  224. for {
  225. // Pop next write & corresponding reads
  226. var read, write string = "", writes[0]
  227. var readCount = 0
  228. for _, readChunk := range reads {
  229. read += readChunk
  230. readCount++
  231. if len(write) <= len(read) {
  232. break
  233. }
  234. if len(write) <= dataMaxSize {
  235. break // atomicity of small writes
  236. }
  237. }
  238. // Compare
  239. if write != read {
  240. t.Errorf("Expected to read %X, got %X", write, read)
  241. }
  242. // Iterate
  243. writes = writes[1:]
  244. reads = reads[readCount:]
  245. if len(writes) == 0 {
  246. break
  247. }
  248. }
  249. }
  250. compareWritesReads(fooWrites, barReads)
  251. compareWritesReads(barWrites, fooReads)
  252. }
  253. // Run go test -update from within this module
  254. // to update the golden test vector file
  255. var update = flag.Bool("update", false, "update .golden files")
  256. func TestDeriveSecretsAndChallengeGolden(t *testing.T) {
  257. goldenFilepath := filepath.Join("testdata", t.Name()+".golden")
  258. if *update {
  259. t.Logf("Updating golden test vector file %s", goldenFilepath)
  260. data := createGoldenTestVectors(t)
  261. cmn.WriteFile(goldenFilepath, []byte(data), 0644)
  262. }
  263. f, err := os.Open(goldenFilepath)
  264. if err != nil {
  265. log.Fatal(err)
  266. }
  267. defer f.Close()
  268. scanner := bufio.NewScanner(f)
  269. for scanner.Scan() {
  270. line := scanner.Text()
  271. params := strings.Split(line, ",")
  272. randSecretVector, err := hex.DecodeString(params[0])
  273. require.Nil(t, err)
  274. randSecret := new([32]byte)
  275. copy((*randSecret)[:], randSecretVector)
  276. locIsLeast, err := strconv.ParseBool(params[1])
  277. require.Nil(t, err)
  278. expectedRecvSecret, err := hex.DecodeString(params[2])
  279. require.Nil(t, err)
  280. expectedSendSecret, err := hex.DecodeString(params[3])
  281. require.Nil(t, err)
  282. expectedChallenge, err := hex.DecodeString(params[4])
  283. require.Nil(t, err)
  284. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  285. require.Equal(t, expectedRecvSecret, (*recvSecret)[:], "Recv Secrets aren't equal")
  286. require.Equal(t, expectedSendSecret, (*sendSecret)[:], "Send Secrets aren't equal")
  287. require.Equal(t, expectedChallenge, (*challenge)[:], "challenges aren't equal")
  288. }
  289. }
  290. // Creates the data for a test vector file.
  291. // The file format is:
  292. // Hex(diffie_hellman_secret), loc_is_least, Hex(recvSecret), Hex(sendSecret), Hex(challenge)
  293. func createGoldenTestVectors(t *testing.T) string {
  294. data := ""
  295. for i := 0; i < 32; i++ {
  296. randSecretVector := cmn.RandBytes(32)
  297. randSecret := new([32]byte)
  298. copy((*randSecret)[:], randSecretVector)
  299. data += hex.EncodeToString((*randSecret)[:]) + ","
  300. locIsLeast := cmn.RandBool()
  301. data += strconv.FormatBool(locIsLeast) + ","
  302. recvSecret, sendSecret, challenge := deriveSecretAndChallenge(randSecret, locIsLeast)
  303. data += hex.EncodeToString((*recvSecret)[:]) + ","
  304. data += hex.EncodeToString((*sendSecret)[:]) + ","
  305. data += hex.EncodeToString((*challenge)[:]) + "\n"
  306. }
  307. return data
  308. }
  309. func BenchmarkSecretConnection(b *testing.B) {
  310. b.StopTimer()
  311. fooSecConn, barSecConn := makeSecretConnPair(b)
  312. fooWriteText := cmn.RandStr(dataMaxSize)
  313. // Consume reads from bar's reader
  314. go func() {
  315. readBuffer := make([]byte, dataMaxSize)
  316. for {
  317. _, err := barSecConn.Read(readBuffer)
  318. if err == io.EOF {
  319. return
  320. } else if err != nil {
  321. b.Fatalf("Failed to read from barSecConn: %v", err)
  322. }
  323. }
  324. }()
  325. b.StartTimer()
  326. for i := 0; i < b.N; i++ {
  327. _, err := fooSecConn.Write([]byte(fooWriteText))
  328. if err != nil {
  329. b.Fatalf("Failed to write to fooSecConn: %v", err)
  330. }
  331. }
  332. b.StopTimer()
  333. if err := fooSecConn.Close(); err != nil {
  334. b.Error(err)
  335. }
  336. //barSecConn.Close() race condition
  337. }
  338. func fingerprint(bz []byte) []byte {
  339. const fbsize = 40
  340. if len(bz) < fbsize {
  341. return bz
  342. } else {
  343. return bz[:fbsize]
  344. }
  345. }