diff --git a/p2p/secret_connection.go b/p2p/secret_connection.go index 8c49a7525..8d1b8e44f 100644 --- a/p2p/secret_connection.go +++ b/p2p/secret_connection.go @@ -26,7 +26,7 @@ const totalFrameSize = dataMaxSize + dataLenSize const sealedFrameSize = totalFrameSize + secretbox.Overhead type SecretConnection struct { - conn io.ReadWriter + conn io.ReadWriteCloser recvBuffer []byte recvNonce *[24]byte sendNonce *[24]byte @@ -37,7 +37,7 @@ type SecretConnection struct { // Performs handshake and returns a new authenticated SecretConnection. // Returns nil if error in handshake. // Caller should call conn.Close() -func MakeSecretConnection(conn io.ReadWriter, locPrivKey acm.PrivKeyEd25519) (*SecretConnection, error) { +func MakeSecretConnection(conn io.ReadWriteCloser, locPrivKey acm.PrivKeyEd25519) (*SecretConnection, error) { locPubKey := locPrivKey.PubKey().(acm.PubKeyEd25519) @@ -157,6 +157,10 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) { return } +func (sc *SecretConnection) Close() error { + return sc.conn.Close() +} + func genEphKeys() (ephPub, ephPriv *[32]byte) { var err error ephPub, ephPriv, err = box.GenerateKey(crand.Reader) @@ -166,7 +170,7 @@ func genEphKeys() (ephPub, ephPriv *[32]byte) { return } -func shareEphPubKey(conn io.ReadWriter, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { +func shareEphPubKey(conn io.ReadWriteCloser, locEphPub *[32]byte) (remEphPub *[32]byte, err error) { var err1, err2 error var wg sync.WaitGroup wg.Add(2) diff --git a/p2p/secret_connection_test.go b/p2p/secret_connection_test.go index e159beba4..dbf289c26 100644 --- a/p2p/secret_connection_test.go +++ b/p2p/secret_connection_test.go @@ -2,7 +2,6 @@ package p2p import ( "bytes" - "fmt" "io" "testing" @@ -10,131 +9,188 @@ import ( . "github.com/tendermint/tendermint/common" ) -type dummyReadWriter struct { - io.Reader - io.Writer +type dummyConn struct { + *io.PipeReader + *io.PipeWriter } -// Each returned ReadWriter is akin to a net.Connection -func makeReadWriterPair() (foo, bar io.ReadWriter) { +func (drw dummyConn) Close() (err error) { + err2 := drw.PipeWriter.CloseWithError(io.EOF) + err1 := drw.PipeReader.Close() + if err2 != nil { + return err + } + return err1 +} + +// Each returned ReadWriteCloser is akin to a net.Connection +func makeDummyConnPair() (fooConn, barConn dummyConn) { barReader, fooWriter := io.Pipe() fooReader, barWriter := io.Pipe() - return dummyReadWriter{fooReader, fooWriter}, dummyReadWriter{barReader, barWriter} + return dummyConn{fooReader, fooWriter}, dummyConn{barReader, barWriter} } -func TestSecretConnectionHandshake(t *testing.T) { - foo, bar := makeReadWriterPair() +func makeSecretConnPair(tb testing.TB) (fooSecConn, barSecConn *SecretConnection) { + fooConn, barConn := makeDummyConnPair() fooPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) fooPubKey := fooPrvKey.PubKey().(acm.PubKeyEd25519) barPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) barPubKey := barPrvKey.PubKey().(acm.PubKeyEd25519) - var fooConn, barConn *SecretConnection Parallel(func() { var err error - fooConn, err = MakeSecretConnection(foo, fooPrvKey) + fooSecConn, err = MakeSecretConnection(fooConn, fooPrvKey) if err != nil { - t.Errorf("Failed to establish SecretConnection for foo: %v", err) + tb.Errorf("Failed to establish SecretConnection for foo: %v", err) return } - if !bytes.Equal(fooConn.RemotePubKey(), barPubKey) { - t.Errorf("Unexpected fooConn.RemotePubKey. Expected %v, got %v", - barPubKey, fooConn.RemotePubKey()) + if !bytes.Equal(fooSecConn.RemotePubKey(), barPubKey) { + tb.Errorf("Unexpected fooSecConn.RemotePubKey. Expected %v, got %v", + barPubKey, fooSecConn.RemotePubKey()) } }, func() { var err error - barConn, err = MakeSecretConnection(bar, barPrvKey) - if barConn == nil { - t.Errorf("Failed to establish SecretConnection for bar: %v", err) + barSecConn, err = MakeSecretConnection(barConn, barPrvKey) + if barSecConn == nil { + tb.Errorf("Failed to establish SecretConnection for bar: %v", err) return } - if !bytes.Equal(barConn.RemotePubKey(), fooPubKey) { - t.Errorf("Unexpected barConn.RemotePubKey. Expected %v, got %v", - fooPubKey, barConn.RemotePubKey()) + if !bytes.Equal(barSecConn.RemotePubKey(), fooPubKey) { + tb.Errorf("Unexpected barSecConn.RemotePubKey. Expected %v, got %v", + fooPubKey, barSecConn.RemotePubKey()) } }) + + return +} + +func TestSecretConnectionHandshake(t *testing.T) { + fooSecConn, barSecConn := makeSecretConnPair(t) + fooSecConn.Close() + barSecConn.Close() } func TestSecretConnectionReadWrite(t *testing.T) { - foo, bar := makeReadWriterPair() - fooPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) - barPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) + fooConn, barConn := makeDummyConnPair() fooWrites, barWrites := []string{}, []string{} fooReads, barReads := []string{}, []string{} - for i := 0; i < 2; i++ { + // Pre-generate the things to write (for foo & bar) + for i := 0; i < 100; i++ { fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) } - fmt.Println("fooWrotes", fooWrites, "\n") - fmt.Println("barWrotes", barWrites, "\n") - - var fooConn, barConn *SecretConnection - Parallel(func() { - var err error - fooConn, err = MakeSecretConnection(foo, fooPrvKey) - if err != nil { - t.Errorf("Failed to establish SecretConnection for foo: %v", err) - return - } - Parallel(func() { - for _, fooWrite := range fooWrites { - fmt.Println("will write foo") - n, err := fooConn.Write([]byte(fooWrite)) - if err != nil { - t.Errorf("Failed to write to fooConn: %v", err) - return + // A helper that will run with (fooConn, fooWrites, fooReads) and vice versa + genNodeRunner := func(nodeConn dummyConn, nodeWrites []string, nodeReads *[]string) func() { + return func() { + // Node handskae + nodePrvKey := acm.PrivKeyEd25519(CRandBytes(32)) + nodeSecretConn, err := MakeSecretConnection(nodeConn, nodePrvKey) + if err != nil { + t.Errorf("Failed to establish SecretConnection for node: %v", err) + return + } + // In parallel, handle reads and writes + Parallel(func() { + // Node writes + for _, nodeWrite := range nodeWrites { + n, err := nodeSecretConn.Write([]byte(nodeWrite)) + if err != nil { + t.Errorf("Failed to write to nodeSecretConn: %v", err) + return + } + if n != len(nodeWrite) { + t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(nodeWrite), n) + return + } } - if n != len(fooWrite) { - t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(fooWrite), n) - return + nodeConn.PipeWriter.Close() + }, func() { + // Node reads + readBuffer := make([]byte, dataMaxSize) + for { + n, err := nodeSecretConn.Read(readBuffer) + if err == io.EOF { + return + } else if err != nil { + t.Errorf("Failed to read from nodeSecretConn: %v", err) + return + } + *nodeReads = append(*nodeReads, string(readBuffer[:n])) } - } - fmt.Println("Done writing foo") - // TODO close foo - }, func() { - fmt.Println("TODO do foo reads") - }) - }, func() { - var err error - barConn, err = MakeSecretConnection(bar, barPrvKey) - if err != nil { - t.Errorf("Failed to establish SecretConnection for bar: %v", err) - return + nodeConn.PipeReader.Close() + }) } - Parallel(func() { - readBuffer := make([]byte, dataMaxSize) - for { - fmt.Println("will read bar") - n, err := barConn.Read(readBuffer) - if err == io.EOF { - return - } else if err != nil { - t.Errorf("Failed to read from barConn: %v", err) - return + } + + // Run foo & bar in parallel + Parallel( + genNodeRunner(fooConn, fooWrites, &fooReads), + genNodeRunner(barConn, barWrites, &barReads), + ) + + // A helper to ensure that the writes and reads match. + // Additionally, small writes (<= dataMaxSize) must be atomically read. + compareWritesReads := func(writes []string, reads []string) { + for { + // Pop next write & corresponding reads + var read, write string = "", writes[0] + var readCount = 0 + for _, readChunk := range reads { + read += readChunk + readCount += 1 + if len(write) <= len(read) { + break + } + if len(write) <= dataMaxSize { + break // atomicity of small writes } - barReads = append(barReads, string(readBuffer[:n])) } - // XXX This does not get called - fmt.Println("Done reading bar") - }, func() { - fmt.Println("TODO do bar writes") - }) - }) + // Compare + if write != read { + t.Errorf("Expected to read %X, got %X", write, read) + } + // Iterate + writes = writes[1:] + reads = reads[readCount:] + if len(writes) == 0 { + break + } + } + } + + compareWritesReads(fooWrites, barReads) + compareWritesReads(barWrites, fooReads) - fmt.Println("fooWrites", fooWrites) - fmt.Println("barReads", barReads) - fmt.Println("barWrites", barWrites) - fmt.Println("fooReads", fooReads) } func BenchmarkSecretConnection(b *testing.B) { b.StopTimer() - b.StartTimer() + fooSecConn, barSecConn := makeSecretConnPair(b) + fooWriteText := RandStr(dataMaxSize) + // Consume reads from bar's reader + go func() { + readBuffer := make([]byte, dataMaxSize) + for { + _, err := barSecConn.Read(readBuffer) + if err == io.EOF { + return + } else if err != nil { + b.Fatalf("Failed to read from barSecConn: %v", err) + } + } + }() + b.StartTimer() for i := 0; i < b.N; i++ { + _, err := fooSecConn.Write([]byte(fooWriteText)) + if err != nil { + b.Fatalf("Failed to write to fooSecConn: %v", err) + } } - b.StopTimer() + + fooSecConn.Close() + //barSecConn.Close() race condition }