Browse Source

fixed handshake test; wrote broken read/write test

pull/111/merge
Jae Kwon 10 years ago
parent
commit
4981a5993d
4 changed files with 109 additions and 30 deletions
  1. +15
    -0
      common/async.go
  2. +3
    -10
      p2p/peer.go
  3. +7
    -4
      p2p/secret_connection.go
  4. +84
    -16
      p2p/secret_connection_test.go

+ 15
- 0
common/async.go View File

@ -0,0 +1,15 @@
package common
import "sync"
func Parallel(tasks ...func()) {
var wg sync.WaitGroup
wg.Add(len(tasks))
for _, task := range tasks {
go func(task func()) {
task()
wg.Done()
}(task)
}
wg.Wait()
}

+ 3
- 10
p2p/peer.go View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"sync"
"sync/atomic" "sync/atomic"
"github.com/tendermint/tendermint/binary" "github.com/tendermint/tendermint/binary"
@ -26,22 +25,16 @@ type Peer struct {
// Before creating a peer with newPeer(), perform a handshake on connection. // Before creating a peer with newPeer(), perform a handshake on connection.
func peerHandshake(conn net.Conn, ourNodeInfo *types.NodeInfo) (*types.NodeInfo, error) { func peerHandshake(conn net.Conn, ourNodeInfo *types.NodeInfo) (*types.NodeInfo, error) {
var peerNodeInfo = new(types.NodeInfo) var peerNodeInfo = new(types.NodeInfo)
var wg sync.WaitGroup
var err1 error var err1 error
var err2 error var err2 error
wg.Add(2)
go func() {
Parallel(func() {
var n int64 var n int64
binary.WriteBinary(ourNodeInfo, conn, &n, &err1) binary.WriteBinary(ourNodeInfo, conn, &n, &err1)
wg.Done()
}()
go func() {
}, func() {
var n int64 var n int64
binary.ReadBinary(peerNodeInfo, conn, &n, &err2) binary.ReadBinary(peerNodeInfo, conn, &n, &err2)
log.Info("Peer handshake", "peerNodeInfo", peerNodeInfo) log.Info("Peer handshake", "peerNodeInfo", peerNodeInfo)
wg.Done()
}()
wg.Wait()
})
if err1 != nil { if err1 != nil {
return nil, err1 return nil, err1
} }


+ 7
- 4
p2p/secret_connection.go View File

@ -7,6 +7,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
//"fmt"
"io" "io"
"sync" "sync"
@ -109,7 +110,8 @@ func (sc *SecretConnection) Write(data []byte) (n int, err error) {
// encrypt the frame // encrypt the frame
var sealedFrame = make([]byte, sealedFrameSize) var sealedFrame = make([]byte, sealedFrameSize)
secretbox.Seal(sealedFrame, frame, sc.sendNonce, sc.shrSecret)
secretbox.Seal(sealedFrame[:0], frame, sc.sendNonce, sc.shrSecret)
// fmt.Printf("secretbox.Seal(sealed:%X,sendNonce:%X,shrSecret:%X\n", sealedFrame, sc.sendNonce, sc.shrSecret)
incr2Nonce(sc.sendNonce) incr2Nonce(sc.sendNonce)
// end encryption // end encryption
@ -139,7 +141,8 @@ func (sc *SecretConnection) Read(data []byte) (n int, err error) {
// decrypt the frame // decrypt the frame
var frame = make([]byte, totalFrameSize) var frame = make([]byte, totalFrameSize)
_, ok := secretbox.Open(frame, sealedFrame, sc.recvNonce, sc.shrSecret)
// fmt.Printf("secretbox.Open(sealed:%X,recvNonce:%X,shrSecret:%X\n", sealedFrame, sc.recvNonce, sc.shrSecret)
_, ok := secretbox.Open(frame[:0], sealedFrame, sc.recvNonce, sc.shrSecret)
if !ok { if !ok {
return n, errors.New("Failed to decrypt SecretConnection") return n, errors.New("Failed to decrypt SecretConnection")
} }
@ -216,8 +219,8 @@ func genNonces(loPubKey, hiPubKey *[32]byte, locIsLo bool) (recvNonce, sendNonce
recvNonce = nonce1 recvNonce = nonce1
sendNonce = nonce2 sendNonce = nonce2
} else { } else {
recvNonce = nonce1
sendNonce = nonce2
recvNonce = nonce2
sendNonce = nonce1
} }
return return
} }


+ 84
- 16
p2p/secret_connection_test.go View File

@ -2,8 +2,8 @@ package p2p
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"sync"
"testing" "testing"
acm "github.com/tendermint/tendermint/account" acm "github.com/tendermint/tendermint/account"
@ -30,35 +30,103 @@ func TestSecretConnectionHandshake(t *testing.T) {
barPubKey := barPrvKey.PubKey().(acm.PubKeyEd25519) barPubKey := barPrvKey.PubKey().(acm.PubKeyEd25519)
var fooConn, barConn *SecretConnection var fooConn, barConn *SecretConnection
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
Parallel(func() {
var err error var err error
fooConn, err = MakeSecretConnection(foo, fooPrvKey) fooConn, err = MakeSecretConnection(foo, fooPrvKey)
if err != nil { if err != nil {
t.Errorf("Failed to establish SecretConnection for foo: %v", err) t.Errorf("Failed to establish SecretConnection for foo: %v", err)
return return
} }
if !bytes.Equal(fooConn.RemotePubKey(), fooPubKey) {
t.Errorf("Unexpected fooConn.RemotePubKey. Expected %X, got %X",
fooPubKey, fooConn.RemotePubKey())
if !bytes.Equal(fooConn.RemotePubKey(), barPubKey) {
t.Errorf("Unexpected fooConn.RemotePubKey. Expected %v, got %v",
barPubKey, fooConn.RemotePubKey())
} }
}()
go func() {
defer wg.Done()
}, func() {
var err error var err error
barConn, err = MakeSecretConnection(bar, barPrvKey) barConn, err = MakeSecretConnection(bar, barPrvKey)
if barConn == nil { if barConn == nil {
t.Errorf("Failed to establish SecretConnection for bar: %v", err) t.Errorf("Failed to establish SecretConnection for bar: %v", err)
return return
} }
if !bytes.Equal(barConn.RemotePubKey(), barPubKey) {
t.Errorf("Unexpected barConn.RemotePubKey. Expected %X, got %X",
barPubKey, barConn.RemotePubKey())
if !bytes.Equal(barConn.RemotePubKey(), fooPubKey) {
t.Errorf("Unexpected barConn.RemotePubKey. Expected %v, got %v",
fooPubKey, barConn.RemotePubKey())
} }
}()
wg.Wait()
})
}
func TestSecretConnectionReadWrite(t *testing.T) {
foo, bar := makeReadWriterPair()
fooPrvKey := acm.PrivKeyEd25519(CRandBytes(32))
barPrvKey := acm.PrivKeyEd25519(CRandBytes(32))
fooWrites, barWrites := []string{}, []string{}
fooReads, barReads := []string{}, []string{}
for i := 0; i < 2; i++ {
fooWrites = append(fooWrites, RandStr((RandInt()%(dataMaxSize*5))+1))
barWrites = append(barWrites, RandStr((RandInt()%(dataMaxSize*5))+1))
}
fmt.Println("fooWrotes", fooWrites, "\n")
fmt.Println("barWrotes", barWrites, "\n")
var fooConn, barConn *SecretConnection
Parallel(func() {
var err error
fooConn, err = MakeSecretConnection(foo, fooPrvKey)
if err != nil {
t.Errorf("Failed to establish SecretConnection for foo: %v", err)
return
}
Parallel(func() {
for _, fooWrite := range fooWrites {
fmt.Println("will write foo")
n, err := fooConn.Write([]byte(fooWrite))
if err != nil {
t.Errorf("Failed to write to fooConn: %v", err)
return
}
if n != len(fooWrite) {
t.Errorf("Failed to write all bytes. Expected %v, wrote %v", len(fooWrite), n)
return
}
}
fmt.Println("Done writing foo")
// TODO close foo
}, func() {
fmt.Println("TODO do foo reads")
})
}, func() {
var err error
barConn, err = MakeSecretConnection(bar, barPrvKey)
if err != nil {
t.Errorf("Failed to establish SecretConnection for bar: %v", err)
return
}
Parallel(func() {
readBuffer := make([]byte, dataMaxSize)
for {
fmt.Println("will read bar")
n, err := barConn.Read(readBuffer)
if err == io.EOF {
return
} else if err != nil {
t.Errorf("Failed to read from barConn: %v", err)
return
}
barReads = append(barReads, string(readBuffer[:n]))
}
// XXX This does not get called
fmt.Println("Done reading bar")
}, func() {
fmt.Println("TODO do bar writes")
})
})
fmt.Println("fooWrites", fooWrites)
fmt.Println("barReads", barReads)
fmt.Println("barWrites", barWrites)
fmt.Println("fooReads", fooReads)
} }
func BenchmarkSecretConnection(b *testing.B) { func BenchmarkSecretConnection(b *testing.B) {


Loading…
Cancel
Save