Browse Source

p2p: add MemoryTransport, an in-memory transport for testing (#5827)

pull/5833/head
Erik Grinaker 4 years ago
committed by GitHub
parent
commit
84ff991387
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 515 additions and 0 deletions
  1. +394
    -0
      p2p/transport_memory.go
  2. +121
    -0
      p2p/transport_memory_test.go

+ 394
- 0
p2p/transport_memory.go View File

@ -0,0 +1,394 @@
package p2p
import (
"context"
"errors"
"fmt"
"io"
"sync"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/crypto/ed25519"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/p2p/conn"
)
const (
MemoryProtocol Protocol = "memory"
)
// MemoryNetwork is an in-memory "network" that uses Go channels to communicate
// between endpoints. Transport endpoints are created with CreateTransport. It
// is primarily used for testing.
type MemoryNetwork struct {
logger log.Logger
mtx sync.RWMutex
transports map[ID]*MemoryTransport
}
// NewMemoryNetwork creates a new in-memory network.
func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
return &MemoryNetwork{
logger: logger,
transports: map[ID]*MemoryTransport{},
}
}
// CreateTransport creates a new memory transport and endpoint for the given
// NodeInfo and private key. Use GenerateTransport() to autogenerate a random
// key and node info.
//
// The transport immediately begins listening on the endpoint "memory:<id>", and
// can be accessed by other transports in the same memory network.
func (n *MemoryNetwork) CreateTransport(
nodeInfo NodeInfo,
privKey crypto.PrivKey,
) (*MemoryTransport, error) {
nodeID := nodeInfo.DefaultNodeID
if nodeID == "" {
return nil, errors.New("no node ID")
}
t := newMemoryTransport(n, nodeInfo, privKey)
n.mtx.Lock()
defer n.mtx.Unlock()
if _, ok := n.transports[nodeID]; ok {
return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
}
n.transports[nodeID] = t
return t, nil
}
// GenerateTransport generates a new transport endpoint by generating a random
// private key and node info. The endpoint address can be obtained via
// Transport.Endpoints().
func (n *MemoryNetwork) GenerateTransport() *MemoryTransport {
privKey := ed25519.GenPrivKey()
nodeID := PubKeyToID(privKey.PubKey())
nodeInfo := NodeInfo{
DefaultNodeID: nodeID,
ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID),
}
t, err := n.CreateTransport(nodeInfo, privKey)
if err != nil {
// GenerateTransport is only used for testing, and the likelihood of
// generating a duplicate node ID is very low, so we'll panic.
panic(err)
}
return t
}
// GetTransport looks up a transport in the network, returning nil if not found.
func (n *MemoryNetwork) GetTransport(id ID) *MemoryTransport {
n.mtx.RLock()
defer n.mtx.RUnlock()
return n.transports[id]
}
// RemoveTransport removes a transport from the network and closes it.
func (n *MemoryNetwork) RemoveTransport(id ID) error {
n.mtx.Lock()
t, ok := n.transports[id]
delete(n.transports, id)
n.mtx.Unlock()
if ok {
// Close may recursively call RemoveTransport() again, but this is safe
// because we've already removed the transport from the map above.
return t.Close()
}
return nil
}
// MemoryTransport is an in-memory transport that's primarily meant for testing.
// It communicates between endpoints using Go channels. To dial a different
// endpoint, both endpoints/transports must be in the same MemoryNetwork.
type MemoryTransport struct {
network *MemoryNetwork
nodeInfo NodeInfo
privKey crypto.PrivKey
logger log.Logger
acceptCh chan *MemoryConnection
closeCh chan struct{}
closeOnce sync.Once
}
// newMemoryTransport creates a new in-memory transport in the given network.
// Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
// to create transports, this is for internal use by MemoryNetwork.
func newMemoryTransport(
network *MemoryNetwork,
nodeInfo NodeInfo,
privKey crypto.PrivKey,
) *MemoryTransport {
return &MemoryTransport{
network: network,
nodeInfo: nodeInfo,
privKey: privKey,
logger: network.logger.With("local",
fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.DefaultNodeID)),
acceptCh: make(chan *MemoryConnection),
closeCh: make(chan struct{}),
}
}
// Accept implements Transport.
func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
select {
case conn := <-t.acceptCh:
t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
return conn, nil
case <-t.closeCh:
return nil, ErrTransportClosed{}
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Dial implements Transport.
func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
if endpoint.Protocol != MemoryProtocol {
return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
}
if endpoint.Path == "" {
return nil, errors.New("no path")
}
if endpoint.PeerID == "" {
return nil, errors.New("no peer ID")
}
t.logger.Info("dialing peer", "remote", endpoint)
peerTransport := t.network.GetTransport(ID(endpoint.Path))
if peerTransport == nil {
return nil, fmt.Errorf("unknown peer %q", endpoint.Path)
}
inCh := make(chan memoryMessage, 1)
outCh := make(chan memoryMessage, 1)
closeCh := make(chan struct{})
closeOnce := sync.Once{}
closer := func() bool {
closed := false
closeOnce.Do(func() {
close(closeCh)
closed = true
})
return closed
}
outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer)
inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer)
select {
case peerTransport.acceptCh <- inConn:
return outConn, nil
case <-peerTransport.closeCh:
return nil, ErrTransportClosed{}
case <-ctx.Done():
return nil, ctx.Err()
}
}
// DialAccept is a convenience function that dials a peer MemoryTransport and
// returns both ends of the connection (A to B and B to A).
func (t *MemoryTransport) DialAccept(
ctx context.Context,
peer *MemoryTransport,
) (Connection, Connection, error) {
endpoints := peer.Endpoints()
if len(endpoints) == 0 {
return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.DefaultNodeID)
}
acceptCh := make(chan Connection, 1)
errCh := make(chan error, 1)
go func() {
conn, err := peer.Accept(ctx)
errCh <- err
acceptCh <- conn
}()
outConn, err := t.Dial(ctx, endpoints[0])
if err != nil {
return nil, nil, err
}
if err = <-errCh; err != nil {
return nil, nil, err
}
inConn := <-acceptCh
return outConn, inConn, nil
}
// Close implements Transport.
func (t *MemoryTransport) Close() error {
err := t.network.RemoveTransport(t.nodeInfo.DefaultNodeID)
t.closeOnce.Do(func() {
close(t.closeCh)
})
t.logger.Info("stopped accepting connections")
return err
}
// Endpoints implements Transport.
func (t *MemoryTransport) Endpoints() []Endpoint {
select {
case <-t.closeCh:
return []Endpoint{}
default:
return []Endpoint{{
Protocol: MemoryProtocol,
PeerID: t.nodeInfo.DefaultNodeID,
Path: string(t.nodeInfo.DefaultNodeID),
}}
}
}
// SetChannelDescriptors implements Transport.
func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) {
}
// MemoryConnection is an in-memory connection between two transports (nodes).
type MemoryConnection struct {
logger log.Logger
local *MemoryTransport
remote *MemoryTransport
receiveCh <-chan memoryMessage
sendCh chan<- memoryMessage
closeCh <-chan struct{}
close func() bool
}
// memoryMessage is used to pass messages internally in the connection.
type memoryMessage struct {
channel byte
message []byte
}
// newMemoryConnection creates a new MemoryConnection. It takes all channels
// (including the closeCh signal channel) on construction, such that they can be
// shared between both ends of the connection.
func newMemoryConnection(
local *MemoryTransport,
remote *MemoryTransport,
receiveCh <-chan memoryMessage,
sendCh chan<- memoryMessage,
closeCh <-chan struct{},
close func() bool,
) *MemoryConnection {
c := &MemoryConnection{
local: local,
remote: remote,
receiveCh: receiveCh,
sendCh: sendCh,
closeCh: closeCh,
close: close,
}
c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
return c
}
// ReceiveMessage implements Connection.
func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
// check close first, since channels are buffered
select {
case <-c.closeCh:
return 0, nil, io.EOF
default:
}
select {
case msg := <-c.receiveCh:
c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
return msg.channel, msg.message, nil
case <-c.closeCh:
return 0, nil, io.EOF
}
}
// SendMessage implements Connection.
func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
// check close first, since channels are buffered
select {
case <-c.closeCh:
return false, io.EOF
default:
}
select {
case c.sendCh <- memoryMessage{channel: chID, message: msg}:
c.logger.Debug("sent message", "channel", chID, "message", msg)
return true, nil
case <-c.closeCh:
return false, io.EOF
}
}
// TrySendMessage implements Connection.
func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
// check close first, since channels are buffered
select {
case <-c.closeCh:
return false, io.EOF
default:
}
select {
case c.sendCh <- memoryMessage{channel: chID, message: msg}:
c.logger.Debug("sent message", "channel", chID, "message", msg)
return true, nil
case <-c.closeCh:
return false, io.EOF
default:
return false, nil
}
}
// Close closes the connection.
func (c *MemoryConnection) Close() error {
if c.close() {
c.logger.Info("closed connection")
}
return nil
}
// FlushClose flushes all pending sends and then closes the connection.
func (c *MemoryConnection) FlushClose() error {
return c.Close()
}
// LocalEndpoint returns the local endpoint for the connection.
func (c *MemoryConnection) LocalEndpoint() Endpoint {
return Endpoint{
PeerID: c.local.nodeInfo.DefaultNodeID,
Protocol: MemoryProtocol,
Path: string(c.local.nodeInfo.DefaultNodeID),
}
}
// RemoteEndpoint returns the remote endpoint for the connection.
func (c *MemoryConnection) RemoteEndpoint() Endpoint {
return Endpoint{
PeerID: c.remote.nodeInfo.DefaultNodeID,
Protocol: MemoryProtocol,
Path: string(c.remote.nodeInfo.DefaultNodeID),
}
}
// PubKey returns the remote peer's public key.
func (c *MemoryConnection) PubKey() crypto.PubKey {
return c.remote.privKey.PubKey()
}
// NodeInfo returns the remote peer's node info.
func (c *MemoryConnection) NodeInfo() NodeInfo {
return c.remote.nodeInfo
}
// Status returns the current connection status.
func (c *MemoryConnection) Status() conn.ConnectionStatus {
return conn.ConnectionStatus{}
}

+ 121
- 0
p2p/transport_memory_test.go View File

@ -0,0 +1,121 @@
package p2p_test
import (
"context"
"io"
"testing"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/p2p"
)
func TestMemoryTransport(t *testing.T) {
ctx := context.Background()
network := p2p.NewMemoryNetwork(log.TestingLogger())
a := network.GenerateTransport()
b := network.GenerateTransport()
c := network.GenerateTransport()
// Dialing a missing endpoint should fail.
_, err := a.Dial(ctx, p2p.Endpoint{
Protocol: p2p.MemoryProtocol,
PeerID: p2p.ID("foo"),
Path: "foo",
})
require.Error(t, err)
// Dialing and accepting a→b and a→c should work.
aToB, bToA, err := a.DialAccept(ctx, b)
require.NoError(t, err)
defer aToB.Close()
defer bToA.Close()
aToC, cToA, err := a.DialAccept(ctx, c)
require.NoError(t, err)
defer aToC.Close()
defer cToA.Close()
// Send and receive a message both ways a→b and b→a
sent, err := aToB.SendMessage(1, []byte("hi!"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err := bToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "hi!", msg)
sent, err = bToA.SendMessage(1, []byte("hello"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = aToB.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "hello", msg)
// Send and receive a message both ways a→c and c→a
sent, err = aToC.SendMessage(1, []byte("foo"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "foo", msg)
sent, err = cToA.SendMessage(1, []byte("bar"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = aToC.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "bar", msg)
// If we close aToB, sending and receiving on either end will error.
err = aToB.Close()
require.NoError(t, err)
_, err = aToB.SendMessage(1, []byte("foo"))
require.Equal(t, io.EOF, err)
_, _, err = aToB.ReceiveMessage()
require.Equal(t, io.EOF, err)
_, err = bToA.SendMessage(1, []byte("foo"))
require.Equal(t, io.EOF, err)
_, _, err = bToA.ReceiveMessage()
require.Equal(t, io.EOF, err)
// We can still send aToC.
sent, err = aToC.SendMessage(1, []byte("foo"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "foo", msg)
// If we close the c transport, it will no longer accept connections,
// but we can still use the open connection.
endpoint := c.Endpoints()[0]
err = c.Close()
require.NoError(t, err)
require.Empty(t, c.Endpoints())
_, err = a.Dial(ctx, endpoint)
require.Error(t, err)
sent, err = aToC.SendMessage(1, []byte("bar"))
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, "bar", msg)
}

Loading…
Cancel
Save