diff --git a/common/async.go b/common/async.go new file mode 100644 index 000000000..1d302c344 --- /dev/null +++ b/common/async.go @@ -0,0 +1,15 @@ +package common + +import "sync" + +func Parallel(tasks ...func()) { + var wg sync.WaitGroup + wg.Add(len(tasks)) + for _, task := range tasks { + go func(task func()) { + task() + wg.Done() + }(task) + } + wg.Wait() +} diff --git a/p2p/peer.go b/p2p/peer.go index 7fa140222..f0d3a6758 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -4,7 +4,6 @@ import ( "fmt" "io" "net" - "sync" "sync/atomic" "github.com/tendermint/tendermint/binary" @@ -26,22 +25,16 @@ type Peer struct { // Before creating a peer with newPeer(), perform a handshake on connection. func peerHandshake(conn net.Conn, ourNodeInfo *types.NodeInfo) (*types.NodeInfo, error) { var peerNodeInfo = new(types.NodeInfo) - var wg sync.WaitGroup var err1 error var err2 error - wg.Add(2) - go func() { + Parallel(func() { var n int64 binary.WriteBinary(ourNodeInfo, conn, &n, &err1) - wg.Done() - }() - go func() { + }, func() { var n int64 binary.ReadBinary(peerNodeInfo, conn, &n, &err2) log.Info("Peer handshake", "peerNodeInfo", peerNodeInfo) - wg.Done() - }() - wg.Wait() + }) if err1 != nil { return nil, err1 } diff --git a/p2p/secret_connection.go b/p2p/secret_connection.go index e6c23dd27..8c49a7525 100644 --- a/p2p/secret_connection.go +++ b/p2p/secret_connection.go @@ -7,6 +7,7 @@ import ( "crypto/sha256" "encoding/binary" "errors" + //"fmt" "io" "sync" @@ -109,7 +110,8 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) { // encrypt the frame var sealedFrame = make([]byte, sealedFrameSize) - secretbox.Seal(sealedFrame, frame, sc.sendNonce, sc.shrSecret) + secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret) + // fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret) incr2Nonce(sc.sendNonce) // end encryption @@ -139,7 +141,8 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) { // decrypt the frame var frame = make([]byte, totalFrameSize) - _, ok := secretbox.Open(frame, sealedFrame, sc.recvNonce, sc.shrSecret) + // fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret) + _, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret) if !ok { return n, errors.New("Failed to decrypt SecretConnection") } @@ -216,8 +219,8 @@ func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce recvNonce = nonce1 sendNonce = nonce2 } else { - recvNonce = nonce1 - sendNonce = nonce2 + recvNonce = nonce2 + sendNonce = nonce1 } return } diff --git a/p2p/secret_connection_test.go b/p2p/secret_connection_test.go index 634d8309a..e159beba4 100644 --- a/p2p/secret_connection_test.go +++ b/p2p/secret_connection_test.go @@ -2,8 +2,8 @@ package p2p import ( "bytes" + "fmt" "io" - "sync" "testing" acm "github.com/tendermint/tendermint/account" @@ -30,35 +30,103 @@ func TestSecretConnectionHandshake(t *testing.T) { barPubKey := barPrvKey.PubKey().(acm.PubKeyEd25519) var fooConn, barConn *SecretConnection - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() + Parallel(func() { var err error fooConn, err = MakeSecretConnection(foo, fooPrvKey) if err != nil { t.Errorf("Failed to establish SecretConnection for foo: %v", err) return } - if !bytes.Equal(fooConn.RemotePubKey(), fooPubKey) { - t.Errorf("Unexpected fooConn.RemotePubKey. Expected %X, got %X", - fooPubKey, fooConn.RemotePubKey()) + if !bytes.Equal(fooConn.RemotePubKey(), barPubKey) { + t.Errorf("Unexpected fooConn.RemotePubKey. Expected %v, got %v", + barPubKey, fooConn.RemotePubKey()) } - }() - go func() { - defer wg.Done() + }, func() { var err error barConn, err = MakeSecretConnection(bar, barPrvKey) if barConn == nil { t.Errorf("Failed to establish SecretConnection for bar: %v", err) return } - if !bytes.Equal(barConn.RemotePubKey(), barPubKey) { - t.Errorf("Unexpected barConn.RemotePubKey. Expected %X, got %X", - barPubKey, barConn.RemotePubKey()) + if !bytes.Equal(barConn.RemotePubKey(), fooPubKey) { + t.Errorf("Unexpected barConn.RemotePubKey. Expected %v, got %v", + fooPubKey, barConn.RemotePubKey()) } - }() - wg.Wait() + }) +} + +func TestSecretConnectionReadWrite(t *testing.T) { + foo, bar := makeReadWriterPair() + fooPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) + barPrvKey := acm.PrivKeyEd25519(CRandBytes(32)) + fooWrites, barWrites := []string{}, []string{} + fooReads, barReads := []string{}, []string{} + + for i := 0; i < 2; i++ { + fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) + barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1)) + } + + fmt.Println("fooWrotes", fooWrites, "\n") + fmt.Println("barWrotes", barWrites, "\n") + + var fooConn, barConn *SecretConnection + Parallel(func() { + var err error + fooConn, err = MakeSecretConnection(foo, fooPrvKey) + if err != nil { + t.Errorf("Failed to establish SecretConnection for foo: %v", err) + return + } + Parallel(func() { + for _, fooWrite := range fooWrites { + fmt.Println("will write foo") + n, err := fooConn.Write([]byte(fooWrite)) + if err != nil { + t.Errorf("Failed to write to fooConn: %v", err) + return + } + if n != len(fooWrite) { + t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(fooWrite), n) + return + } + } + fmt.Println("Done writing foo") + // TODO close foo + }, func() { + fmt.Println("TODO do foo reads") + }) + }, func() { + var err error + barConn, err = MakeSecretConnection(bar, barPrvKey) + if err != nil { + t.Errorf("Failed to establish SecretConnection for bar: %v", err) + return + } + Parallel(func() { + readBuffer := make([]byte, dataMaxSize) + for { + fmt.Println("will read bar") + n, err := barConn.Read(readBuffer) + if err == io.EOF { + return + } else if err != nil { + t.Errorf("Failed to read from barConn: %v", err) + return + } + barReads = append(barReads, string(readBuffer[:n])) + } + // XXX This does not get called + fmt.Println("Done reading bar") + }, func() { + fmt.Println("TODO do bar writes") + }) + }) + + fmt.Println("fooWrites", fooWrites) + fmt.Println("barReads", barReads) + fmt.Println("barWrites", barWrites) + fmt.Println("fooReads", fooReads) } func BenchmarkSecretConnection(b *testing.B) {