
p2p: tighten up and test Transport API (#6020)

This tightens up the new P2P `Transport` API and infrastructure, fixes a bunch of bugs and inconsistencies, and adds tests.
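For orientation, a minimal sketch of the revised Accept contract (not part of this diff; the package and helper name are hypothetical, the p2p identifiers are the ones changed below): closing a transport now unblocks Accept with io.EOF rather than a context error or ErrTransportClosed.

package p2pexample // hypothetical

import (
	"io"

	"github.com/tendermint/tendermint/p2p"
)

// acceptLoop accepts inbound connections until the transport is closed.
// With this change, Accept takes no context; Close() unblocks it and it
// returns io.EOF.
func acceptLoop(transport p2p.Transport, handle func(p2p.Connection)) error {
	for {
		conn, err := transport.Accept()
		if err == io.EOF {
			return nil // transport closed, stop accepting
		} else if err != nil {
			return err
		}
		go handle(conn)
	}
}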
Erik Grinaker authored 3 years ago, committed by GitHub
commit 1f39f808e1
13 changed files with 1247 additions and 455 deletions
1. mempool/reactor_test.go (+4 -4)
2. p2p/key.go (+2 -5)
3. p2p/peer.go (+22 -17)
4. p2p/router.go (+8 -8)
5. p2p/router_test.go (+2 -4)
6. p2p/switch.go (+1 -2)
7. p2p/switch_test.go (+1 -1)
8. p2p/transport.go (+69 -52)
9. p2p/transport_mconn.go (+131 -105)
10. p2p/transport_mconn_test.go (+208 -0)
11. p2p/transport_memory.go (+139 -145)
12. p2p/transport_memory_test.go (+23 -112)
13. p2p/transport_test.go (+637 -0)

mempool/reactor_test.go (+4 -4)

@ -39,7 +39,7 @@ type reactorTestSuite struct {
func setup(t *testing.T, cfg *cfg.MempoolConfig, logger log.Logger, chBuf uint) *reactorTestSuite {
t.Helper()
pID := make([]byte, 16)
pID := make([]byte, 20)
_, err := rng.Read(pID)
require.NoError(t, err)
@ -313,7 +313,7 @@ func TestReactorNoBroadcastToSender(t *testing.T) {
func TestMempoolIDsBasic(t *testing.T) {
ids := newMempoolIDs()
peerID, err := p2p.NewNodeID("00ffaa")
peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899")
require.NoError(t, err)
ids.ReserveForPeer(peerID)
@ -399,7 +399,7 @@ func TestDontExhaustMaxActiveIDs(t *testing.T) {
}
}()
peerID, err := p2p.NewNodeID("00ffaa")
peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899")
require.NoError(t, err)
// ensure the reactor does not panic (i.e. exhaust active IDs)
@ -427,7 +427,7 @@ func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) {
// 0 is already reserved for UnknownPeerID
ids := newMempoolIDs()
peerID, err := p2p.NewNodeID("00ffaa")
peerID, err := p2p.NewNodeID("0011223344556677889900112233445566778899")
require.NoError(t, err)
for i := 0; i < maxActiveIDs-1; i++ {


p2p/key.go (+2 -5)

@ -22,11 +22,8 @@ type NodeID string
// NewNodeID returns a lowercased (normalized) NodeID.
func NewNodeID(nodeID string) (NodeID, error) {
if _, err := NodeID(nodeID).Bytes(); err != nil {
return NodeID(""), err
}
return NodeID(strings.ToLower(nodeID)), nil
n := NodeID(strings.ToLower(nodeID))
return n, n.Validate()
}
// NodeIDFromPubKey returns the node ID corresponding to the given PubKey. It's
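A quick sketch of the new behavior (illustrative only; the helper is hypothetical): NewNodeID now normalizes and validates in one step, and as the mempool test updates above suggest, IDs must be the full 20-byte (40 hex character) form.

package p2pexample // hypothetical

import (
	"fmt"

	"github.com/tendermint/tendermint/p2p"
)

func nodeIDExample() {
	// Mixed case is normalized to lowercase and the ID is validated.
	id, err := p2p.NewNodeID("00112233445566778899AABBCCDDEEFF00112233")
	fmt.Println(id, err)

	// Truncated IDs such as "00ffaa" are expected to fail validation now,
	// which is why the mempool tests switch to full-length IDs.
	_, err = p2p.NewNodeID("00ffaa")
	fmt.Println(err)
}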


p2p/peer.go (+22 -17)

@ -156,21 +156,26 @@ func (a PeerAddress) Validate() error {
// String formats the address as a URL string.
func (a PeerAddress) String() string {
// Handle opaque URLs.
if a.Hostname == "" {
s := fmt.Sprintf("%s:%s", a.Protocol, a.NodeID)
if a.Path != "" {
s += "@" + a.Path
u := url.URL{Scheme: string(a.Protocol)}
if a.NodeID != "" {
u.User = url.User(string(a.NodeID))
}
switch {
case a.Hostname != "":
if a.Port > 0 {
u.Host = net.JoinHostPort(a.Hostname, strconv.Itoa(int(a.Port)))
} else {
u.Host = a.Hostname
}
return s
}
s := fmt.Sprintf("%s://%s@%s", a.Protocol, a.NodeID, a.Hostname)
if a.Port > 0 {
s += ":" + strconv.Itoa(int(a.Port))
u.Path = a.Path
case a.Protocol != "":
u.Opaque = a.Path // e.g. memory:foo
case a.Path != "" && a.Path[0] != '/':
u.Path = "/" + a.Path // e.g. some/path
default:
u.Path = a.Path // e.g. /some/path
}
s += a.Path // We've already normalized the path with appropriate prefix in ParsePeerAddress()
return s
return strings.TrimPrefix(u.String(), "//")
}
// PeerStatus specifies peer statuses.
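For reference, a rough sketch of the two URL forms the reworked String() produces (illustrative only; field names as in the diff, exact encoding left to net/url): networked addresses render as protocol://nodeID@host:port/path, while non-networked addresses render as an opaque protocol:path URL.

package p2pexample // hypothetical

import (
	"fmt"

	"github.com/tendermint/tendermint/p2p"
)

func peerAddressExample(nodeID p2p.NodeID) {
	// Networked address, e.g. "mconn://<nodeID>@127.0.0.1:26657".
	networked := p2p.PeerAddress{
		Protocol: p2p.MConnProtocol,
		NodeID:   nodeID,
		Hostname: "127.0.0.1",
		Port:     26657,
	}
	fmt.Println(networked.String())

	// Opaque (non-networked) address, e.g. "memory:<path>".
	opaque := p2p.PeerAddress{
		Protocol: p2p.MemoryProtocol,
		NodeID:   nodeID,
		Path:     string(nodeID),
	}
	fmt.Println(opaque.String())
}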
@ -1475,12 +1480,12 @@ func (p *peer) processMessages() {
p.onError(err)
return
}
reactor, ok := p.reactors[chID]
reactor, ok := p.reactors[byte(chID)]
if !ok {
p.onError(fmt.Errorf("unknown channel %v", chID))
return
}
reactor.Receive(chID, p, msg)
reactor.Receive(byte(chID), p, msg)
}
}
@ -1555,7 +1560,7 @@ func (p *peer) Send(chID byte, msgBytes []byte) bool {
} else if !p.hasChannel(chID) {
return false
}
res, err := p.conn.SendMessage(chID, msgBytes)
res, err := p.conn.SendMessage(ChannelID(chID), msgBytes)
if err == io.EOF {
return false
} else if err != nil {
@ -1580,7 +1585,7 @@ func (p *peer) TrySend(chID byte, msgBytes []byte) bool {
} else if !p.hasChannel(chID) {
return false
}
res, err := p.conn.TrySendMessage(chID, msgBytes)
res, err := p.conn.TrySendMessage(ChannelID(chID), msgBytes)
if err == io.EOF {
return false
} else if err != nil {


p2p/router.go (+8 -8)

@ -293,10 +293,10 @@ func (r *Router) acceptPeers(transport Transport) {
// FIXME: The old P2P stack supported ABCI-based IP address filtering via
// /p2p/filter/addr/<ip> queries, do we want to implement this here as well?
// Filtering by node ID is probably better.
conn, err := transport.Accept(ctx)
conn, err := transport.Accept()
switch err {
case nil:
case ErrTransportClosed{}, io.EOF, context.Canceled:
case io.EOF:
r.logger.Debug("stopping accept routine", "transport", transport)
return
default:
@ -536,8 +536,8 @@ func (r *Router) receivePeer(peerID NodeID, conn Connection) error {
}
r.channelMtx.RLock()
queue, ok := r.channelQueues[ChannelID(chID)]
messageType := r.channelMessages[ChannelID(chID)]
queue, ok := r.channelQueues[chID]
messageType := r.channelMessages[chID]
r.channelMtx.RUnlock()
if !ok {
r.logger.Error("dropping message for unknown channel", "peer", peerID, "channel", chID)
@ -558,8 +558,7 @@ func (r *Router) receivePeer(peerID NodeID, conn Connection) error {
}
select {
// FIXME: ReceiveMessage() should return ChannelID.
case queue.enqueue() <- Envelope{channelID: ChannelID(chID), From: peerID, Message: msg}:
case queue.enqueue() <- Envelope{channelID: chID, From: peerID, Message: msg}:
r.logger.Debug("received message", "peer", peerID, "message", msg)
case <-queue.closed():
r.logger.Error("channel closed, dropping message", "peer", peerID, "channel", chID)
@ -580,8 +579,7 @@ func (r *Router) sendPeer(peerID NodeID, conn Connection, queue queue) error {
continue
}
// FIXME: SendMessage() should take ChannelID.
_, err = conn.SendMessage(byte(envelope.channelID), bz)
_, err = conn.SendMessage(envelope.channelID, bz)
if err != nil {
return err
}
@ -631,6 +629,8 @@ func (r *Router) OnStart() error {
}
// OnStop implements service.Service.
//
// FIXME: This needs to close transports as well.
func (r *Router) OnStop() {
// Collect all active queues, so we can wait for them to close.
queues := []queue{}


p2p/router_test.go (+2 -4)

@ -50,8 +50,7 @@ func TestRouter(t *testing.T) {
logger := log.TestingLogger()
network := p2p.NewMemoryNetwork(logger)
nodeInfo, privKey := generateNode()
transport, err := network.CreateTransport(nodeInfo.NodeID)
require.NoError(t, err)
transport := network.CreateTransport(nodeInfo.NodeID)
defer transport.Close()
chID := p2p.ChannelID(1)
@ -62,8 +61,7 @@ func TestRouter(t *testing.T) {
peerManager, err := p2p.NewPeerManager(dbm.NewMemDB(), p2p.PeerManagerOptions{})
require.NoError(t, err)
peerInfo, peerKey := generateNode()
peerTransport, err := network.CreateTransport(peerInfo.NodeID)
require.NoError(t, err)
peerTransport := network.CreateTransport(peerInfo.NodeID)
defer peerTransport.Close()
peerRouter, err := p2p.NewRouter(
logger.With("peerID", i),


p2p/switch.go (+1 -2)

@ -669,8 +669,7 @@ func (sw *Switch) IsPeerPersistent(na *NetAddress) bool {
func (sw *Switch) acceptRoutine() {
for {
var peerNodeInfo NodeInfo
ctx := context.Background()
c, err := sw.transport.Accept(ctx)
c, err := sw.transport.Accept()
if err == nil {
// NOTE: The legacy MConn transport did handshaking in Accept(),
// which was asynchronous and avoided head-of-line-blocking.


p2p/switch_test.go (+1 -1)

@ -706,7 +706,7 @@ func (et errorTransport) Protocols() []Protocol {
return []Protocol{"error"}
}
func (et errorTransport) Accept(context.Context) (Connection, error) {
func (et errorTransport) Accept() (Connection, error) {
return nil, et.acceptErr
}
func (errorTransport) Dial(context.Context, Endpoint) (Connection, error) {


p2p/transport.go (+69 -52)

@ -5,13 +5,14 @@ import (
"errors"
"fmt"
"net"
"strconv"
"github.com/tendermint/tendermint/crypto"
"github.com/tendermint/tendermint/p2p/conn"
)
const (
// defaultProtocol is the default protocol used for PeerAddress when
// a protocol isn't explicitly given as a URL scheme.
defaultProtocol Protocol = MConnProtocol
)
@ -20,69 +21,77 @@ type Protocol string
// Transport is a connection-oriented mechanism for exchanging data with a peer.
type Transport interface {
// Protocols returns the protocols the transport supports, which the
// router uses to pick a transport for a PeerAddress.
// Protocols returns the protocols supported by the transport. The Router
// uses this to pick a transport for an Endpoint.
Protocols() []Protocol
// Accept waits for the next inbound connection on a listening endpoint, or
// returns io.EOF if the transport is closed.
Accept(context.Context) (Connection, error)
// Endpoints returns the local endpoints the transport is listening on, if any.
//
// How to listen is transport-dependent, e.g. MConnTransport uses Listen() while
// MemoryTransport starts listening via MemoryNetwork.CreateTransport().
Endpoints() []Endpoint
// Accept waits for the next inbound connection on a listening endpoint, blocking
// until either a connection is available or the transport is closed. On closure,
// io.EOF is returned and further Accept calls are futile.
Accept() (Connection, error)
// Dial creates an outbound connection to an endpoint.
Dial(context.Context, Endpoint) (Connection, error)
// Endpoints lists endpoints the transport is listening on.
Endpoints() []Endpoint
// Close stops accepting new connections, but does not close active connections.
Close() error
// Stringer is used to display the transport, e.g. in logs.
//
// Without this, the logger may use reflection to access and display
// internal fields -- these are written concurrently, which can trigger the
// race detector or even cause a panic.
// internal fields. These can be written to concurrently, which can trigger
// the race detector or even cause a panic.
fmt.Stringer
}
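As the comment on Protocols() suggests, the Router can use it to match an Endpoint to a transport. A minimal sketch of that lookup (hypothetical helper, not Router code from this diff):

package p2pexample // hypothetical

import (
	"fmt"

	"github.com/tendermint/tendermint/p2p"
)

// transportForEndpoint picks the first transport that supports the
// endpoint's protocol.
func transportForEndpoint(transports []p2p.Transport, endpoint p2p.Endpoint) (p2p.Transport, error) {
	for _, transport := range transports {
		for _, protocol := range transport.Protocols() {
			if protocol == endpoint.Protocol {
				return transport, nil
			}
		}
	}
	return nil, fmt.Errorf("no transport for protocol %q", endpoint.Protocol)
}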
// Connection represents an established connection between two endpoints.
//
// FIXME: This is a temporary interface while we figure out whether we'll be
// adopting QUIC or not. If we do, this should be a byte-oriented multi-stream
// interface with one goroutine consuming each stream, and the MConnection
// transport either needs protocol changes or a shim. For details, see:
// FIXME: This is a temporary interface for backwards-compatibility with the
// current MConnection-protocol, which is message-oriented. It should be
// migrated to a byte-oriented multi-stream interface instead, which would allow
// e.g. adopting QUIC and making message framing, traffic scheduling, and node
// handshakes a Router concern shared across all transports. However, this
// requires MConnection protocol changes or a shim. For details, see:
// https://github.com/tendermint/spec/pull/227
//
// FIXME: The interface is currently very broad in order to accommodate
// MConnection behavior that the rest of the P2P stack relies on. This should be
// removed once the P2P core is rewritten.
// MConnection behavior that the legacy P2P stack relies on. It should be
// cleaned up when the legacy stack is removed.
type Connection interface {
// Handshake handshakes with the remote peer. It must be called immediately
// after the connection is established, and returns the remote peer's node
// info and public key. The caller is responsible for validation.
// Handshake executes a node handshake with the remote peer. It must be
// called immediately after the connection is established, and returns the
// remote peer's node info and public key. The caller is responsible for
// validation.
//
// FIXME: The handshaking should really be the Router's responsibility, but
// FIXME: The handshake should really be the Router's responsibility, but
// that requires the connection interface to be byte-oriented rather than
// message-oriented (see comment above).
Handshake(context.Context, NodeInfo, crypto.PrivKey) (NodeInfo, crypto.PubKey, error)
// ReceiveMessage returns the next message received on the connection,
// blocking until one is available. io.EOF is returned when closed.
ReceiveMessage() (chID byte, msg []byte, err error)
// blocking until one is available. Returns io.EOF if closed.
ReceiveMessage() (ChannelID, []byte, error)
// SendMessage sends a message on the connection.
// FIXME: For compatibility with the current Peer, it returns an additional
// boolean false if the message timed out waiting to be accepted into the
// send buffer.
SendMessage(chID byte, msg []byte) (bool, error)
// SendMessage sends a message on the connection. Returns io.EOF if closed.
//
// FIXME: For compatibility with the legacy P2P stack, it returns an
// additional boolean false if the message timed out waiting to be accepted
// into the send buffer. This should be removed.
SendMessage(ChannelID, []byte) (bool, error)
// TrySendMessage is a non-blocking version of SendMessage that returns
// immediately if the message buffer is full. It returns true if the message
// was accepted.
//
// FIXME: This is here for backwards-compatibility with the current Peer
// code, and should be removed when possible.
TrySendMessage(chID byte, msg []byte) (bool, error)
// FIXME: This method is here for backwards-compatibility with the legacy
// P2P stack and should be removed.
TrySendMessage(ChannelID, []byte) (bool, error)
// LocalEndpoint returns the local endpoint for the connection.
LocalEndpoint() Endpoint
@ -98,68 +107,76 @@ type Connection interface {
// FIXME: This only exists for backwards-compatibility with the current
// MConnection implementation. There should really be a separate Flush()
// method, but there is no easy way to synchronously flush pending data with
// the current MConnection structure.
// the current MConnection code.
FlushClose() error
// Status returns the current connection status.
// FIXME: Only here for compatibility with the current Peer code.
Status() conn.ConnectionStatus
// Stringer is used to display the connection, e.g. in logs.
//
// Without this, the logger may use reflection to access and display
// internal fields. These can be written to concurrently, which can trigger
// the race detector or even cause a panic.
fmt.Stringer
}
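A minimal sketch of how a caller drives a Connection under the revised signatures (hypothetical helper; validation of the returned NodeInfo and key is deliberately elided, since the interface leaves it to the caller): handshake first, then message I/O keyed by ChannelID.

package p2pexample // hypothetical

import (
	"context"
	"errors"

	"github.com/tendermint/tendermint/crypto"
	"github.com/tendermint/tendermint/p2p"
)

// echoOnce handshakes a freshly established connection and echoes one
// message back on the channel it arrived on.
func echoOnce(ctx context.Context, c p2p.Connection, selfInfo p2p.NodeInfo, privKey crypto.PrivKey) error {
	peerInfo, peerKey, err := c.Handshake(ctx, selfInfo, privKey)
	if err != nil {
		return err
	}
	_, _ = peerInfo, peerKey // caller-side validation elided in this sketch

	chID, msg, err := c.ReceiveMessage() // io.EOF once the connection closes
	if err != nil {
		return err
	}
	ok, err := c.SendMessage(chID, msg)
	if err != nil {
		return err
	}
	if !ok {
		return errors.New("send timed out waiting for the send buffer")
	}
	return c.Close()
}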
// Endpoint represents a transport connection endpoint, either local or remote.
//
// Endpoints are not necessarily networked (see e.g. MemoryTransport) but all
// networked endpoints must use IP as the underlying transport protocol to allow
// e.g. IP address filtering. Either IP or Path (or both) must be set.
type Endpoint struct {
// Protocol specifies the transport protocol, used by the router to pick a
// transport for an endpoint.
// Protocol specifies the transport protocol.
Protocol Protocol
// Path is an optional, arbitrary transport-specific path or identifier.
Path string
// IP is an IP address (v4 or v6) to connect to. If set, this defines the
// endpoint as a networked endpoint.
IP net.IP
// Port is a network port (either TCP or UDP). If not set, a default port
// may be used depending on the protocol.
// Port is a network port (either TCP or UDP). If 0, a default port may be
// used depending on the protocol.
Port uint16
// Path is an optional transport-specific path or identifier.
Path string
}
// PeerAddress converts the endpoint into a peer address for a given node ID.
// PeerAddress converts the endpoint into a PeerAddress for the given node ID.
func (e Endpoint) PeerAddress(nodeID NodeID) PeerAddress {
address := PeerAddress{
NodeID: nodeID,
Protocol: e.Protocol,
Path: e.Path,
}
if e.IP != nil {
if len(e.IP) > 0 {
address.Hostname = e.IP.String()
address.Port = e.Port
}
return address
}
// String formats an endpoint as a URL string.
// String formats the endpoint as a URL string.
func (e Endpoint) String() string {
if e.IP == nil {
return fmt.Sprintf("%s:%s", e.Protocol, e.Path)
}
s := fmt.Sprintf("%s://%s", e.Protocol, e.IP)
if e.Port > 0 {
s += strconv.Itoa(int(e.Port))
}
s += e.Path
return s
return e.PeerAddress("").String()
}
// Validate validates an endpoint.
// Validate validates the endpoint.
func (e Endpoint) Validate() error {
switch {
case e.Protocol == "":
return errors.New("endpoint has no protocol")
case len(e.IP) > 0 && e.IP.To16() == nil:
return fmt.Errorf("invalid IP address %v", e.IP)
case e.Port > 0 && len(e.IP) == 0:
return fmt.Errorf("endpoint has port %v but no IP", e.Port)
case len(e.IP) == 0 && e.Path == "":
return errors.New("endpoint has neither path nor IP")
default:
return nil
}
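To illustrate the Validate() rules above (a sketch, not code from the diff): an endpoint needs a protocol plus either an IP or a path, and networked endpoints may leave the port at 0 to use a protocol default.

package p2pexample // hypothetical

import (
	"fmt"
	"net"

	"github.com/tendermint/tendermint/p2p"
)

func endpointExamples() error {
	endpoints := []p2p.Endpoint{
		// Networked endpoint: IP, and optionally a port.
		{Protocol: p2p.MConnProtocol, IP: net.IPv4(127, 0, 0, 1), Port: 26657},
		// Non-networked endpoint: path only, e.g. for MemoryTransport.
		{Protocol: p2p.MemoryProtocol, Path: "a"},
	}
	for _, e := range endpoints {
		if err := e.Validate(); err != nil {
			return err
		}
		fmt.Println(e.String()) // formatted via PeerAddress("").String()
	}
	return nil
}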


p2p/transport_mconn.go (+131 -105)

@ -5,9 +5,10 @@ import (
"errors"
"fmt"
"io"
"math"
"net"
"strconv"
"sync"
"time"
"golang.org/x/net/netutil"
@ -76,51 +77,70 @@ func (m *MConnTransport) Protocols() []Protocol {
return []Protocol{MConnProtocol, TCPProtocol}
}
// Endpoints implements Transport.
func (m *MConnTransport) Endpoints() []Endpoint {
if m.listener == nil {
return []Endpoint{}
}
select {
case <-m.closeCh:
return []Endpoint{}
default:
}
endpoint := Endpoint{
Protocol: MConnProtocol,
}
if addr, ok := m.listener.Addr().(*net.TCPAddr); ok {
endpoint.IP = addr.IP
endpoint.Port = uint16(addr.Port)
}
return []Endpoint{endpoint}
}
// Listen asynchronously listens for inbound connections on the given endpoint.
// It must be called exactly once before calling Accept(), and the caller must
// call Close() to shut down the listener.
//
// FIXME: Listen currently only supports listening on a single endpoint, it
// might be useful to support listening on multiple addresses (e.g. IPv4 and
// IPv6, or a private and public address) via multiple Listen() calls.
func (m *MConnTransport) Listen(endpoint Endpoint) error {
if m.listener != nil {
return errors.New("transport is already listening")
}
endpoint, err := m.normalizeEndpoint(endpoint)
if err != nil {
return fmt.Errorf("invalid MConn listen endpoint %q: %w", endpoint, err)
if err := m.validateEndpoint(endpoint); err != nil {
return err
}
m.listener, err = net.Listen("tcp", fmt.Sprintf("%v:%v", endpoint.IP, endpoint.Port))
listener, err := net.Listen("tcp", net.JoinHostPort(
endpoint.IP.String(), strconv.Itoa(int(endpoint.Port))))
if err != nil {
return err
}
if m.options.MaxAcceptedConnections > 0 {
m.listener = netutil.LimitListener(m.listener, int(m.options.MaxAcceptedConnections))
// FIXME: This will establish the inbound connection but simply hang it
// until another connection is released. It would probably be better to
// return an error to the remote peer or close the connection. This is
// also a DoS vector since the connection will take up kernel resources.
// This was just carried over from the legacy P2P stack.
listener = netutil.LimitListener(listener, int(m.options.MaxAcceptedConnections))
}
m.listener = listener
return nil
}
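For reference, a minimal Listen setup matching the factory added in transport_mconn_test.go below (a sketch; the helper name and channel descriptor are placeholders):

package p2pexample // hypothetical

import (
	"net"

	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/p2p"
	"github.com/tendermint/tendermint/p2p/conn"
)

// newLoopbackTransport listens on a random loopback port, as the common
// transport tests do.
func newLoopbackTransport(logger log.Logger) (*p2p.MConnTransport, error) {
	transport := p2p.NewMConnTransport(
		logger,
		conn.DefaultMConnConfig(),
		[]*p2p.ChannelDescriptor{{ID: 0x01, Priority: 1}},
		p2p.MConnTransportOptions{},
	)
	err := transport.Listen(p2p.Endpoint{
		Protocol: p2p.MConnProtocol,
		IP:       net.IPv4(127, 0, 0, 1),
		Port:     0, // pick a random available port
	})
	if err != nil {
		return nil, err
	}
	return transport, nil
}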
// Accept implements Transport.
func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) {
func (m *MConnTransport) Accept() (Connection, error) {
if m.listener == nil {
return nil, errors.New("transport is not listening")
}
if deadline, ok := ctx.Deadline(); ok {
if tcpListener, ok := m.listener.(*net.TCPListener); ok {
// FIXME: This probably needs to have a goroutine that overrides the
// deadline on context cancellation as well.
if err := tcpListener.SetDeadline(deadline); err != nil {
return nil, err
}
}
}
tcpConn, err := m.listener.Accept()
if err != nil {
select {
case <-m.closeCh:
return nil, io.EOF
case <-ctx.Done():
return nil, ctx.Err()
default:
return nil, err
}
@ -131,36 +151,28 @@ func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) {
// Dial implements Transport.
func (m *MConnTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
endpoint, err := m.normalizeEndpoint(endpoint)
if err != nil {
if err := m.validateEndpoint(endpoint); err != nil {
return nil, err
}
if endpoint.Port == 0 {
endpoint.Port = 26657
}
dialer := net.Dialer{}
tcpConn, err := dialer.DialContext(ctx, "tcp",
net.JoinHostPort(endpoint.IP.String(), fmt.Sprintf("%v", endpoint.Port)))
tcpConn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(
endpoint.IP.String(), strconv.Itoa(int(endpoint.Port))))
if err != nil {
return nil, err
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
return nil, err
}
}
return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil
}
// Endpoints implements Transport.
func (m *MConnTransport) Endpoints() []Endpoint {
if m.listener == nil {
return []Endpoint{}
}
endpoint := Endpoint{
Protocol: MConnProtocol,
}
if addr, ok := m.listener.Addr().(*net.TCPAddr); ok {
endpoint.IP = addr.IP
endpoint.Port = uint16(addr.Port)
}
return []Endpoint{endpoint}
}
// Close implements Transport.
func (m *MConnTransport) Close() error {
var err error
@ -173,24 +185,21 @@ func (m *MConnTransport) Close() error {
return err
}
// normalizeEndpoint normalizes and validates an endpoint.
func (m *MConnTransport) normalizeEndpoint(endpoint Endpoint) (Endpoint, error) {
// validateEndpoint validates an endpoint.
func (m *MConnTransport) validateEndpoint(endpoint Endpoint) error {
if err := endpoint.Validate(); err != nil {
return Endpoint{}, err
return err
}
if endpoint.Protocol != MConnProtocol && endpoint.Protocol != TCPProtocol {
return Endpoint{}, fmt.Errorf("unsupported protocol %q", endpoint.Protocol)
return fmt.Errorf("unsupported protocol %q", endpoint.Protocol)
}
if len(endpoint.IP) == 0 {
return Endpoint{}, errors.New("endpoint must have an IP address")
return errors.New("endpoint has no IP address")
}
if endpoint.Path != "" {
return Endpoint{}, fmt.Errorf("endpoint cannot have path (got %q)", endpoint.Path)
return fmt.Errorf("endpoints with path not supported (got %q)", endpoint.Path)
}
if endpoint.Port == 0 {
endpoint.Port = 26657
}
return endpoint, nil
return nil
}
// mConnConnection implements Connection for MConnTransport.
@ -209,7 +218,7 @@ type mConnConnection struct {
// mConnMessage passes MConnection messages through internal channels.
type mConnMessage struct {
channelID byte
channelID ChannelID
payload []byte
}
@ -226,52 +235,72 @@ func newMConnConnection(
mConnConfig: mConnConfig,
channelDescs: channelDescs,
receiveCh: make(chan mConnMessage),
errorCh: make(chan error),
errorCh: make(chan error, 1), // buffered to avoid onError leak
closeCh: make(chan struct{}),
}
}
// Handshake implements Connection.
//
// FIXME: Since the MConnection code panics, we need to recover it and turn it
// into an error. We should remove panics instead.
func (c *mConnConnection) Handshake(
ctx context.Context,
nodeInfo NodeInfo,
privKey crypto.PrivKey,
) (peerInfo NodeInfo, peerKey crypto.PubKey, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("recovered from panic: %v", r)
}
) (NodeInfo, crypto.PubKey, error) {
var (
mconn *conn.MConnection
peerInfo NodeInfo
peerKey crypto.PubKey
errCh = make(chan error, 1)
)
// To handle context cancellation, we need to do the handshake in a
// goroutine and abort the blocking network calls by closing the connection
// when the context is cancelled.
go func() {
// FIXME: Since the MConnection code panics, we need to recover it and turn it
// into an error. We should remove panics instead.
defer func() {
if r := recover(); r != nil {
errCh <- fmt.Errorf("recovered from panic: %v", r)
}
}()
var err error
mconn, peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey)
errCh <- err
}()
peerInfo, peerKey, err = c.handshake(ctx, nodeInfo, privKey)
return
select {
case <-ctx.Done():
_ = c.Close()
return NodeInfo{}, nil, ctx.Err()
case err := <-errCh:
if err != nil {
return NodeInfo{}, nil, err
}
c.mconn = mconn
c.logger = mconn.Logger
if err = c.mconn.Start(); err != nil {
return NodeInfo{}, nil, err
}
return peerInfo, peerKey, nil
}
}
// handshake is a helper for Handshake, simplifying error handling so we can
// keep panic recovery in Handshake. It sets c.mconn.
//
// FIXME: Move this into Handshake() when MConnection no longer panics.
// keep context handling and panic recovery in Handshake. It returns an
// unstarted but handshaked MConnection, to avoid concurrent field writes.
func (c *mConnConnection) handshake(
ctx context.Context,
nodeInfo NodeInfo,
privKey crypto.PrivKey,
) (NodeInfo, crypto.PubKey, error) {
) (*conn.MConnection, NodeInfo, crypto.PubKey, error) {
if c.mconn != nil {
return NodeInfo{}, nil, errors.New("connection is already handshaked")
}
if deadline, ok := ctx.Deadline(); ok {
if err := c.conn.SetDeadline(deadline); err != nil {
return NodeInfo{}, nil, err
}
return nil, NodeInfo{}, nil, errors.New("connection is already handshaked")
}
secretConn, err := conn.MakeSecretConnection(c.conn, privKey)
if err != nil {
return NodeInfo{}, nil, err
return nil, NodeInfo{}, nil, err
}
var pbPeerInfo p2pproto.NodeInfo
@ -286,20 +315,14 @@ func (c *mConnConnection) handshake(
}()
for i := 0; i < cap(errCh); i++ {
if err = <-errCh; err != nil {
return NodeInfo{}, nil, err
return nil, NodeInfo{}, nil, err
}
}
peerInfo, err := NodeInfoFromProto(&pbPeerInfo)
if err != nil {
return NodeInfo{}, nil, err
return nil, NodeInfo{}, nil, err
}
if err = c.conn.SetDeadline(time.Time{}); err != nil {
return NodeInfo{}, nil, err
}
c.logger = c.logger.With("peer", c.RemoteEndpoint().PeerAddress(peerInfo.NodeID))
mconn := conn.NewMConnectionWithConfig(
secretConn,
c.channelDescs,
@ -307,31 +330,29 @@ func (c *mConnConnection) handshake(
c.onError,
c.mConnConfig,
)
mconn.SetLogger(c.logger)
if err = mconn.Start(); err != nil {
return NodeInfo{}, nil, err
}
c.mconn = mconn
mconn.SetLogger(c.logger.With("peer", c.RemoteEndpoint().PeerAddress(peerInfo.NodeID)))
return peerInfo, secretConn.RemotePubKey(), nil
return mconn, peerInfo, secretConn.RemotePubKey(), nil
}
// onReceive is a callback for MConnection received messages.
func (c *mConnConnection) onReceive(channelID byte, payload []byte) {
func (c *mConnConnection) onReceive(chID byte, payload []byte) {
select {
case c.receiveCh <- mConnMessage{channelID: channelID, payload: payload}:
case c.receiveCh <- mConnMessage{channelID: ChannelID(chID), payload: payload}:
case <-c.closeCh:
}
}
// onError is a callback for MConnection errors. The error is passed to errorCh,
// which is only consumed by ReceiveMessage() for parity with the old
// MConnection behavior.
// onError is a callback for MConnection errors. The error is passed via errorCh
// to ReceiveMessage (but not SendMessage, for legacy P2P stack behavior).
func (c *mConnConnection) onError(e interface{}) {
err, ok := e.(error)
if !ok {
err = fmt.Errorf("%v", err)
}
// We have to close the connection here, since MConnection will have stopped
// the service on any errors.
_ = c.Close()
select {
case c.errorCh <- err:
case <-c.closeCh:
@ -339,37 +360,42 @@ func (c *mConnConnection) onError(e interface{}) {
}
// String displays connection information.
// FIXME: This is here for backwards compatibility with existing logging,
// it should probably just return RemoteEndpoint().String(), if anything.
func (c *mConnConnection) String() string {
endpoint := c.RemoteEndpoint()
return fmt.Sprintf("MConn{%v:%v}", endpoint.IP, endpoint.Port)
return c.RemoteEndpoint().String()
}
// SendMessage implements Connection.
func (c *mConnConnection) SendMessage(channelID byte, msg []byte) (bool, error) {
// We don't check errorCh here, to preserve old MConnection behavior.
func (c *mConnConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) {
if chID > math.MaxUint8 {
return false, fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID)
}
select {
case err := <-c.errorCh:
return false, err
case <-c.closeCh:
return false, io.EOF
default:
return c.mconn.Send(channelID, msg), nil
return c.mconn.Send(byte(chID), msg), nil
}
}
// TrySendMessage implements Connection.
func (c *mConnConnection) TrySendMessage(channelID byte, msg []byte) (bool, error) {
// We don't check errorCh here, to preserve old MConnection behavior.
func (c *mConnConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) {
if chID > math.MaxUint8 {
return false, fmt.Errorf("MConnection only supports 1-byte channel IDs (got %v)", chID)
}
select {
case err := <-c.errorCh:
return false, err
case <-c.closeCh:
return false, io.EOF
default:
return c.mconn.TrySend(channelID, msg), nil
return c.mconn.TrySend(byte(chID), msg), nil
}
}
// ReceiveMessage implements Connection.
func (c *mConnConnection) ReceiveMessage() (byte, []byte, error) {
func (c *mConnConnection) ReceiveMessage() (ChannelID, []byte, error) {
select {
case err := <-c.errorCh:
return 0, nil, err
@ -416,7 +442,7 @@ func (c *mConnConnection) Status() conn.ConnectionStatus {
func (c *mConnConnection) Close() error {
var err error
c.closeOnce.Do(func() {
if c.mconn != nil {
if c.mconn != nil && c.mconn.IsRunning() {
err = c.mconn.Stop()
} else {
err = c.conn.Close()
@ -430,7 +456,7 @@ func (c *mConnConnection) Close() error {
func (c *mConnConnection) FlushClose() error {
var err error
c.closeOnce.Do(func() {
if c.mconn != nil {
if c.mconn != nil && c.mconn.IsRunning() {
c.mconn.FlushStop()
} else {
err = c.conn.Close()


p2p/transport_mconn_test.go (+208 -0)

@ -0,0 +1,208 @@
package p2p_test
import (
"io"
"net"
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/libs/log"
"github.com/tendermint/tendermint/p2p"
"github.com/tendermint/tendermint/p2p/conn"
)
// Transports are mainly tested by common tests in transport_test.go, we
// register a transport factory here to get included in those tests.
func init() {
testTransports["mconn"] = func(t *testing.T) p2p.Transport {
transport := p2p.NewMConnTransport(
log.TestingLogger(),
conn.DefaultMConnConfig(),
[]*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}},
p2p.MConnTransportOptions{},
)
err := transport.Listen(p2p.Endpoint{
Protocol: p2p.MConnProtocol,
IP: net.IPv4(127, 0, 0, 1),
Port: 0, // assign a random port
})
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, transport.Close())
})
return transport
}
}
func TestMConnTransport_AcceptBeforeListen(t *testing.T) {
transport := p2p.NewMConnTransport(
log.TestingLogger(),
conn.DefaultMConnConfig(),
[]*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}},
p2p.MConnTransportOptions{
MaxAcceptedConnections: 2,
},
)
t.Cleanup(func() {
_ = transport.Close()
})
_, err := transport.Accept()
require.Error(t, err)
require.NotEqual(t, io.EOF, err) // io.EOF should be returned after Close()
}
func TestMConnTransport_AcceptMaxAcceptedConnections(t *testing.T) {
transport := p2p.NewMConnTransport(
log.TestingLogger(),
conn.DefaultMConnConfig(),
[]*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}},
p2p.MConnTransportOptions{
MaxAcceptedConnections: 2,
},
)
t.Cleanup(func() {
_ = transport.Close()
})
err := transport.Listen(p2p.Endpoint{
Protocol: p2p.MConnProtocol,
IP: net.IPv4(127, 0, 0, 1),
})
require.NoError(t, err)
require.NotEmpty(t, transport.Endpoints())
endpoint := transport.Endpoints()[0]
// Start a goroutine to just accept any connections.
acceptCh := make(chan p2p.Connection, 10)
go func() {
for {
conn, err := transport.Accept()
if err != nil {
return
}
acceptCh <- conn
}
}()
// The first two connections should be accepted just fine.
dial1, err := transport.Dial(ctx, endpoint)
require.NoError(t, err)
defer dial1.Close()
accept1 := <-acceptCh
defer accept1.Close()
require.Equal(t, dial1.LocalEndpoint(), accept1.RemoteEndpoint())
dial2, err := transport.Dial(ctx, endpoint)
require.NoError(t, err)
defer dial2.Close()
accept2 := <-acceptCh
defer accept2.Close()
require.Equal(t, dial2.LocalEndpoint(), accept2.RemoteEndpoint())
// The third connection will be dialed successfully, but the accept should
// not go through.
dial3, err := transport.Dial(ctx, endpoint)
require.NoError(t, err)
defer dial3.Close()
select {
case <-acceptCh:
require.Fail(t, "unexpected accept")
case <-time.After(time.Second):
}
// However, once either of the other connections are closed, the accept
// goes through.
require.NoError(t, accept1.Close())
accept3 := <-acceptCh
defer accept3.Close()
require.Equal(t, dial3.LocalEndpoint(), accept3.RemoteEndpoint())
}
func TestMConnTransport_Listen(t *testing.T) {
testcases := []struct {
endpoint p2p.Endpoint
ok bool
}{
// Valid v4 and v6 addresses, with mconn and tcp protocols.
{p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero}, true},
{p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4(127, 0, 0, 1)}, true},
{p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6zero}, true},
{p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv6loopback}, true},
{p2p.Endpoint{Protocol: p2p.TCPProtocol, IP: net.IPv4zero}, true},
// Invalid endpoints.
{p2p.Endpoint{}, false},
{p2p.Endpoint{Protocol: p2p.MConnProtocol, Path: "foo"}, false},
{p2p.Endpoint{Protocol: p2p.MConnProtocol, IP: net.IPv4zero, Path: "foo"}, false},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.endpoint.String(), func(t *testing.T) {
t.Cleanup(leaktest.Check(t))
transport := p2p.NewMConnTransport(
log.TestingLogger(),
conn.DefaultMConnConfig(),
[]*p2p.ChannelDescriptor{{ID: byte(chID), Priority: 1}},
p2p.MConnTransportOptions{},
)
t.Cleanup(func() {
_ = transport.Close()
})
// Transport should not listen on any endpoints yet.
require.Empty(t, transport.Endpoints())
// Start listening, and check any expected errors.
err := transport.Listen(tc.endpoint)
if !tc.ok {
require.Error(t, err)
return
}
require.NoError(t, err)
// Start a goroutine to just accept any connections.
go func() {
for {
conn, err := transport.Accept()
if err != nil {
return
}
defer func() {
_ = conn.Close()
}()
}
}()
// Check the endpoint.
endpoints := transport.Endpoints()
require.Len(t, endpoints, 1)
endpoint := endpoints[0]
require.Equal(t, p2p.MConnProtocol, endpoint.Protocol)
if tc.endpoint.IP.IsUnspecified() {
require.True(t, endpoint.IP.IsUnspecified(),
"expected unspecified IP, got %v", endpoint.IP)
} else {
require.True(t, tc.endpoint.IP.Equal(endpoint.IP),
"expected %v, got %v", tc.endpoint.IP, endpoint.IP)
}
require.NotZero(t, endpoint.Port)
require.Empty(t, endpoint.Path)
// Dialing the endpoint should work.
conn, err := transport.Dial(ctx, endpoint)
require.NoError(t, err)
require.NoError(t, conn.Close())
// Trying to listen again should error.
err = transport.Listen(tc.endpoint)
require.Error(t, err)
})
}
}

p2p/transport_memory.go (+139 -145)

@ -15,11 +15,16 @@ import (
const (
MemoryProtocol Protocol = "memory"
// bufferSize is the channel buffer size of MemoryConnection.
bufferSize = 1
)
// MemoryNetwork is an in-memory "network" that uses Go channels to communicate
// between endpoints. Transport endpoints are created with CreateTransport. It
// is primarily used for testing.
// MemoryNetwork is an in-memory "network" that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// Network endpoints are allocated via CreateTransport(), which takes a node ID,
// and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
type MemoryNetwork struct {
logger log.Logger
@ -35,19 +40,19 @@ func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
}
}
// CreateTransport creates a new memory transport and endpoint with the given
// node ID. It immediately begins listening on the endpoint "memory:<id>", and
// can be accessed by other transports in the same memory network.
func (n *MemoryNetwork) CreateTransport(nodeID NodeID) (*MemoryTransport, error) {
// CreateTransport creates a new memory transport endpoint with the given node
// ID and immediately begins listening on the address "memory:<id>". It panics
// if the node ID is already in use (which is fine, since this is for tests).
func (n *MemoryNetwork) CreateTransport(nodeID NodeID) *MemoryTransport {
t := newMemoryTransport(n, nodeID)
n.mtx.Lock()
defer n.mtx.Unlock()
if _, ok := n.transports[nodeID]; ok {
return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
}
n.transports[nodeID] = t
return t, nil
return t
}
// GetTransport looks up a transport in the network, returning nil if not found.
@ -58,7 +63,7 @@ func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
}
// RemoveTransport removes a transport from the network and closes it.
func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
func (n *MemoryNetwork) RemoveTransport(id NodeID) {
n.mtx.Lock()
t, ok := n.transports[id]
delete(n.transports, id)
@ -67,39 +72,46 @@ func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
if ok {
// Close may recursively call RemoveTransport() again, but this is safe
// because we've already removed the transport from the map above.
return t.Close()
if err := t.Close(); err != nil {
n.logger.Error("failed to close memory transport", "id", id, "err", err)
}
}
return nil
}
// MemoryTransport is an in-memory transport that's primarily meant for testing.
// It communicates between endpoints using Go channels. To dial a different
// endpoint, both endpoints/transports must be in the same MemoryNetwork.
// Size returns the number of transports in the network.
func (n *MemoryNetwork) Size() int {
return len(n.transports)
}
// MemoryTransport is an in-memory transport that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// New transports are allocated with MemoryNetwork.CreateTransport(). To contact
// a different endpoint, both transports must be in the same MemoryNetwork.
type MemoryTransport struct {
logger log.Logger
network *MemoryNetwork
nodeID NodeID
logger log.Logger
acceptCh chan *MemoryConnection
closeCh chan struct{}
closeOnce sync.Once
}
// newMemoryTransport creates a new in-memory transport in the given network.
// Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
// to create transports, this is for internal use by MemoryNetwork.
// newMemoryTransport creates a new MemoryTransport. This is for internal use by
// MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
func newMemoryTransport(network *MemoryNetwork, nodeID NodeID) *MemoryTransport {
return &MemoryTransport{
logger: network.logger.With("local", nodeID),
network: network,
nodeID: nodeID,
logger: network.logger.With("local", fmt.Sprintf("%v:%v", MemoryProtocol, nodeID)),
acceptCh: make(chan *MemoryConnection),
closeCh: make(chan struct{}),
}
}
// String displays the transport.
// String implements Transport.
func (t *MemoryTransport) String() string {
return string(MemoryProtocol)
}
@ -109,16 +121,27 @@ func (t *MemoryTransport) Protocols() []Protocol {
return []Protocol{MemoryProtocol}
}
// Endpoints implements Transport.
func (t *MemoryTransport) Endpoints() []Endpoint {
select {
case <-t.closeCh:
return []Endpoint{}
default:
return []Endpoint{{
Protocol: MemoryProtocol,
Path: string(t.nodeID),
}}
}
}
// Accept implements Transport.
func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
func (t *MemoryTransport) Accept() (Connection, error) {
select {
case conn := <-t.acceptCh:
t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
return conn, nil
case <-t.closeCh:
return nil, io.EOF
case <-ctx.Done():
return nil, ctx.Err()
}
}
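Putting the memory pieces together, a rough test-style sketch of dialing between two in-memory transports (hypothetical helper, similar in spirit to the dialAccept helper used by the common tests in transport_test.go):

package p2pexample // hypothetical

import (
	"context"

	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/p2p"
)

// memoryDialAccept wires two in-memory transports together and returns both
// ends of the resulting connection.
func memoryDialAccept(ctx context.Context, aID, bID p2p.NodeID) (p2p.Connection, p2p.Connection, error) {
	network := p2p.NewMemoryNetwork(log.TestingLogger())
	a := network.CreateTransport(aID)
	b := network.CreateTransport(bID)

	// b is reachable at "memory:<bID>" as soon as it is created.
	acceptCh := make(chan p2p.Connection, 1)
	errCh := make(chan error, 1)
	go func() {
		conn, err := b.Accept()
		errCh <- err
		acceptCh <- conn
	}()

	outConn, err := a.Dial(ctx, b.Endpoints()[0])
	if err != nil {
		return nil, nil, err
	}
	if err := <-errCh; err != nil {
		return nil, nil, err
	}
	return outConn, <-acceptCh, nil
}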
@ -134,124 +157,104 @@ func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connecti
if err != nil {
return nil, err
}
t.logger.Info("dialing peer", "remote", endpoint)
peerTransport := t.network.GetTransport(nodeID)
if peerTransport == nil {
t.logger.Info("dialing peer", "remote", nodeID)
peer := t.network.GetTransport(nodeID)
if peer == nil {
return nil, fmt.Errorf("unknown peer %q", nodeID)
}
inCh := make(chan memoryMessage, 1)
outCh := make(chan memoryMessage, 1)
inCh := make(chan memoryMessage, bufferSize)
outCh := make(chan memoryMessage, bufferSize)
closer := tmsync.NewCloser()
outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closer)
inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closer)
outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
select {
case peerTransport.acceptCh <- inConn:
case peer.acceptCh <- inConn:
return outConn, nil
case <-peerTransport.closeCh:
return nil, ErrTransportClosed{}
case <-peer.closeCh:
return nil, io.EOF
case <-ctx.Done():
return nil, ctx.Err()
}
}
// DialAccept is a convenience function that dials a peer MemoryTransport and
// returns both ends of the connection (A to B and B to A).
func (t *MemoryTransport) DialAccept(
ctx context.Context,
peer *MemoryTransport,
) (Connection, Connection, error) {
endpoints := peer.Endpoints()
if len(endpoints) == 0 {
return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeID)
}
acceptCh := make(chan Connection, 1)
errCh := make(chan error, 1)
go func() {
conn, err := peer.Accept(ctx)
errCh <- err
acceptCh <- conn
}()
outConn, err := t.Dial(ctx, endpoints[0])
if err != nil {
return nil, nil, err
}
if err = <-errCh; err != nil {
return nil, nil, err
}
inConn := <-acceptCh
return outConn, inConn, nil
}
// Close implements Transport.
func (t *MemoryTransport) Close() error {
err := t.network.RemoveTransport(t.nodeID)
t.network.RemoveTransport(t.nodeID)
t.closeOnce.Do(func() {
close(t.closeCh)
t.logger.Info("closed transport")
})
t.logger.Info("stopped accepting connections")
return err
}
// Endpoints implements Transport.
func (t *MemoryTransport) Endpoints() []Endpoint {
select {
case <-t.closeCh:
return []Endpoint{}
default:
return []Endpoint{{
Protocol: MemoryProtocol,
Path: string(t.nodeID),
}}
}
return nil
}
// MemoryConnection is an in-memory connection between two transports (nodes).
// MemoryConnection is an in-memory connection between two transport endpoints.
type MemoryConnection struct {
logger log.Logger
local *MemoryTransport
remote *MemoryTransport
logger log.Logger
localID NodeID
remoteID NodeID
receiveCh <-chan memoryMessage
sendCh chan<- memoryMessage
closer *tmsync.Closer
}
// memoryMessage is used to pass messages internally in the connection.
// For handshakes, nodeInfo and pubKey are set instead of channel and message.
// memoryMessage is passed internally, containing either a message or handshake.
type memoryMessage struct {
channel byte
message []byte
channelID ChannelID
message []byte
// For handshakes.
nodeInfo NodeInfo
nodeInfo *NodeInfo
pubKey crypto.PubKey
}
// newMemoryConnection creates a new MemoryConnection. It takes all channels
// (including the closeCh signal channel) on construction, such that they can be
// shared between both ends of the connection.
// newMemoryConnection creates a new MemoryConnection.
func newMemoryConnection(
local *MemoryTransport,
remote *MemoryTransport,
logger log.Logger,
localID NodeID,
remoteID NodeID,
receiveCh <-chan memoryMessage,
sendCh chan<- memoryMessage,
closer *tmsync.Closer,
) *MemoryConnection {
c := &MemoryConnection{
local: local,
remote: remote,
return &MemoryConnection{
logger: logger.With("remote", remoteID),
localID: localID,
remoteID: remoteID,
receiveCh: receiveCh,
sendCh: sendCh,
closer: closer,
}
c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
return c
}
// String implements Connection.
func (c *MemoryConnection) String() string {
return c.RemoteEndpoint().String()
}
// LocalEndpoint implements Connection.
func (c *MemoryConnection) LocalEndpoint() Endpoint {
return Endpoint{
Protocol: MemoryProtocol,
Path: string(c.localID),
}
}
// RemoteEndpoint implements Connection.
func (c *MemoryConnection) RemoteEndpoint() Endpoint {
return Endpoint{
Protocol: MemoryProtocol,
Path: string(c.remoteID),
}
}
// Status implements Connection.
func (c *MemoryConnection) Status() conn.ConnectionStatus {
return conn.ConnectionStatus{}
}
// Handshake implements Connection.
@ -261,27 +264,32 @@ func (c *MemoryConnection) Handshake(
privKey crypto.PrivKey,
) (NodeInfo, crypto.PubKey, error) {
select {
case c.sendCh <- memoryMessage{nodeInfo: nodeInfo, pubKey: privKey.PubKey()}:
case <-ctx.Done():
return NodeInfo{}, nil, ctx.Err()
case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
case <-c.closer.Done():
return NodeInfo{}, nil, io.EOF
case <-ctx.Done():
return NodeInfo{}, nil, ctx.Err()
}
select {
case msg := <-c.receiveCh:
c.logger.Debug("handshake complete")
return msg.nodeInfo, msg.pubKey, nil
case <-ctx.Done():
return NodeInfo{}, nil, ctx.Err()
if msg.nodeInfo == nil {
return NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
}
c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
return *msg.nodeInfo, msg.pubKey, nil
case <-c.closer.Done():
return NodeInfo{}, nil, io.EOF
case <-ctx.Done():
return NodeInfo{}, nil, ctx.Err()
}
}
// ReceiveMessage implements Connection.
func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
// check close first, since channels are buffered
func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
// Check close first, since channels are buffered. Otherwise, below select
// may non-deterministically return non-error even when closed.
select {
case <-c.closer.Done():
return 0, nil, io.EOF
@ -290,16 +298,17 @@ func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
select {
case msg := <-c.receiveCh:
c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
return msg.channel, msg.message, nil
c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
return msg.channelID, msg.message, nil
case <-c.closer.Done():
return 0, nil, io.EOF
}
}
// SendMessage implements Connection.
func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
// check close first, since channels are buffered
func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) {
// Check close first, since channels are buffered. Otherwise, below select
// may non-deterministically return non-error even when closed.
select {
case <-c.closer.Done():
return false, io.EOF
@ -307,8 +316,8 @@ func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
}
select {
case c.sendCh <- memoryMessage{channel: chID, message: msg}:
c.logger.Debug("sent message", "channel", chID, "message", msg)
case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
c.logger.Debug("sent message", "chID", chID, "msg", msg)
return true, nil
case <-c.closer.Done():
return false, io.EOF
@ -316,8 +325,9 @@ func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
}
// TrySendMessage implements Connection.
func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
// check close first, since channels are buffered
func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) {
// Check close first, since channels are buffered. Otherwise, below select
// may non-deterministically return non-error even when closed.
select {
case <-c.closer.Done():
return false, io.EOF
@ -325,8 +335,8 @@ func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
}
select {
case c.sendCh <- memoryMessage{channel: chID, message: msg}:
c.logger.Debug("sent message", "channel", chID, "message", msg)
case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
c.logger.Debug("sent message", "chID", chID, "msg", msg)
return true, nil
case <-c.closer.Done():
return false, io.EOF
@ -335,35 +345,19 @@ func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
}
}
// Close closes the connection.
// Close implements Connection.
func (c *MemoryConnection) Close() error {
c.closer.Close()
c.logger.Info("closed connection")
select {
case <-c.closer.Done():
return nil
default:
c.closer.Close()
c.logger.Info("closed connection")
}
return nil
}
// FlushClose flushes all pending sends and then closes the connection.
// FlushClose implements Connection.
func (c *MemoryConnection) FlushClose() error {
return c.Close()
}
// LocalEndpoint returns the local endpoint for the connection.
func (c *MemoryConnection) LocalEndpoint() Endpoint {
return Endpoint{
Protocol: MemoryProtocol,
Path: string(c.local.nodeID),
}
}
// RemoteEndpoint returns the remote endpoint for the connection.
func (c *MemoryConnection) RemoteEndpoint() Endpoint {
return Endpoint{
Protocol: MemoryProtocol,
Path: string(c.remote.nodeID),
}
}
// Status returns the current connection status.
func (c *MemoryConnection) Status() conn.ConnectionStatus {
return conn.ConnectionStatus{}
}

p2p/transport_memory_test.go (+23 -112)

@ -1,8 +1,8 @@
package p2p_test
import (
"context"
"io"
"bytes"
"encoding/hex"
"testing"
"github.com/stretchr/testify/require"
@ -10,114 +10,25 @@ import (
"github.com/tendermint/tendermint/p2p"
)
func TestMemoryTransport(t *testing.T) {
ctx := context.Background()
network := p2p.NewMemoryNetwork(log.TestingLogger())
a, err := network.CreateTransport("0a")
require.NoError(t, err)
b, err := network.CreateTransport("0b")
require.NoError(t, err)
c, err := network.CreateTransport("0c")
require.NoError(t, err)
// Dialing a missing endpoint should fail.
_, err = a.Dial(ctx, p2p.Endpoint{
Protocol: p2p.MemoryProtocol,
Path: "foo",
})
require.Error(t, err)
// Dialing and accepting a→b and a→c should work.
aToB, bToA, err := a.DialAccept(ctx, b)
require.NoError(t, err)
defer aToB.Close()
defer bToA.Close()
aToC, cToA, err := a.DialAccept(ctx, c)
require.NoError(t, err)
defer aToC.Close()
defer cToA.Close()
// Send and receive a message both ways a→b and b→a
sent, err := aToB.SendMessage(1, []byte{0x01})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err := bToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x01}, msg)
sent, err = bToA.SendMessage(1, []byte{0x02})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = aToB.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x02}, msg)
// Send and receive a message both ways a→c and c→a
sent, err = aToC.SendMessage(1, []byte{0x03})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x03}, msg)
sent, err = cToA.SendMessage(1, []byte{0x04})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = aToC.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x04}, msg)
// If we close aToB, sending and receiving on either end will error.
err = aToB.Close()
require.NoError(t, err)
_, err = aToB.SendMessage(1, []byte{0x05})
require.Equal(t, io.EOF, err)
_, _, err = aToB.ReceiveMessage()
require.Equal(t, io.EOF, err)
_, err = bToA.SendMessage(1, []byte{0x06})
require.Equal(t, io.EOF, err)
_, _, err = bToA.ReceiveMessage()
require.Equal(t, io.EOF, err)
// We can still send aToC.
sent, err = aToC.SendMessage(1, []byte{0x07})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x07}, msg)
// If we close the c transport, it will no longer accept connections,
// but we can still use the open connection.
endpoint := c.Endpoints()[0]
err = c.Close()
require.NoError(t, err)
require.Empty(t, c.Endpoints())
_, err = a.Dial(ctx, endpoint)
require.Error(t, err)
sent, err = aToC.SendMessage(1, []byte{0x08})
require.NoError(t, err)
require.True(t, sent)
ch, msg, err = cToA.ReceiveMessage()
require.NoError(t, err)
require.EqualValues(t, 1, ch)
require.EqualValues(t, []byte{0x08}, msg)
// Transports are mainly tested by common tests in transport_test.go, we
// register a transport factory here to get included in those tests.
func init() {
var network *p2p.MemoryNetwork // shared by transports in the same test
testTransports["memory"] = func(t *testing.T) p2p.Transport {
if network == nil {
network = p2p.NewMemoryNetwork(log.TestingLogger())
}
i := byte(network.Size())
nodeID, err := p2p.NewNodeID(hex.EncodeToString(bytes.Repeat([]byte{i<<4 + i}, 20)))
require.NoError(t, err)
transport := network.CreateTransport(nodeID)
t.Cleanup(func() {
require.NoError(t, transport.Close())
network = nil // set up a new memory network for the next test
})
return transport
}
}

p2p/transport_test.go (+637 -0)

@ -0,0 +1,637 @@
package p2p_test
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/fortytw2/leaktest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tendermint/tendermint/crypto/ed25519"
"github.com/tendermint/tendermint/libs/bytes"
"github.com/tendermint/tendermint/p2p"
)
// transportFactory is used to set up transports for tests.
type transportFactory func(t *testing.T) p2p.Transport
var (
ctx = context.Background() // convenience context
chID = p2p.ChannelID(1) // channel ID for use in tests
testTransports = map[string]transportFactory{} // registry for withTransports
)
// withTransports is a test helper that runs a test against all transports
// registered in testTransports.
func withTransports(t *testing.T, tester func(*testing.T, transportFactory)) {
t.Helper()
for name, transportFactory := range testTransports {
transportFactory := transportFactory
t.Run(name, func(t *testing.T) {
t.Cleanup(leaktest.Check(t))
tester(t, transportFactory)
})
}
}
func TestTransport_AcceptClose(t *testing.T) {
// Just test accept unblock on close, happy path is tested widely elsewhere.
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
// In-progress Accept should error on concurrent close.
errCh := make(chan error, 1)
go func() {
time.Sleep(200 * time.Millisecond)
errCh <- a.Close()
}()
_, err := a.Accept()
require.Error(t, err)
require.Equal(t, io.EOF, err)
require.NoError(t, <-errCh)
// Closed transport should return error immediately.
_, err = a.Accept()
require.Error(t, err)
require.Equal(t, io.EOF, err)
})
}
func TestTransport_DialEndpoints(t *testing.T) {
ipTestCases := []struct {
ip net.IP
ok bool
}{
{net.IPv4zero, true},
{net.IPv6zero, true},
{nil, false},
{net.IPv4bcast, false},
{net.IPv4allsys, false},
{[]byte{1, 2, 3}, false},
{[]byte{1, 2, 3, 4, 5}, false},
}
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
endpoints := a.Endpoints()
require.NotEmpty(t, endpoints)
endpoint := endpoints[0]
// Spawn a goroutine to simply accept any connections until closed.
go func() {
for {
conn, err := a.Accept()
if err != nil {
return
}
_ = conn.Close()
}
}()
// Dialing self should work.
conn, err := a.Dial(ctx, endpoint)
require.NoError(t, err)
require.NoError(t, conn.Close())
// Dialing empty endpoint should error.
_, err = a.Dial(ctx, p2p.Endpoint{})
require.Error(t, err)
// Dialing without protocol should error.
noProtocol := endpoint
noProtocol.Protocol = ""
_, err = a.Dial(ctx, noProtocol)
require.Error(t, err)
// Dialing with invalid protocol should error.
fooProtocol := endpoint
fooProtocol.Protocol = "foo"
_, err = a.Dial(ctx, fooProtocol)
require.Error(t, err)
// Tests for networked endpoints (with IP).
if len(endpoint.IP) > 0 {
for _, tc := range ipTestCases {
tc := tc
t.Run(tc.ip.String(), func(t *testing.T) {
e := endpoint
e.IP = tc.ip
conn, err := a.Dial(ctx, e)
if tc.ok {
require.NoError(t, conn.Close())
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
// Non-networked endpoints should error.
noIP := endpoint
noIP.IP = nil
noIP.Port = 0
noIP.Path = "foo"
_, err := a.Dial(ctx, noIP)
require.Error(t, err)
} else {
// Tests for non-networked endpoints (no IP).
noPath := endpoint
noPath.Path = ""
_, err = a.Dial(ctx, noPath)
require.Error(t, err)
}
})
}
func TestTransport_Dial(t *testing.T) {
// Most just tests dial failures, happy path is tested widely elsewhere.
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
require.NotEmpty(t, a.Endpoints())
require.NotEmpty(t, b.Endpoints())
aEndpoint := a.Endpoints()[0]
bEndpoint := b.Endpoints()[0]
// Context cancellation should error. We can't test timeouts since we'd
// need a non-responsive endpoint.
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
_, err := a.Dial(cancelCtx, bEndpoint)
require.Error(t, err)
require.Equal(t, err, context.Canceled)
// Unavailable endpoint should error.
err = b.Close()
require.NoError(t, err)
_, err = a.Dial(ctx, bEndpoint)
require.Error(t, err)
// Dialing from a closed transport should still work.
errCh := make(chan error, 1)
go func() {
conn, err := a.Accept()
if err == nil {
_ = conn.Close()
}
errCh <- err
}()
conn, err := b.Dial(ctx, aEndpoint)
require.NoError(t, err)
require.NoError(t, conn.Close())
require.NoError(t, <-errCh)
})
}
func TestTransport_Endpoints(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
// Both transports return valid and different endpoints.
aEndpoints := a.Endpoints()
bEndpoints := b.Endpoints()
require.NotEmpty(t, aEndpoints)
require.NotEmpty(t, bEndpoints)
require.NotEqual(t, aEndpoints, bEndpoints)
for _, endpoint := range append(aEndpoints, bEndpoints...) {
err := endpoint.Validate()
require.NoError(t, err, "invalid endpoint %q", endpoint)
}
// When closed, the transport should no longer return any endpoints.
err := a.Close()
require.NoError(t, err)
require.Empty(t, a.Endpoints())
require.NotEmpty(t, b.Endpoints())
})
}
func TestTransport_Protocols(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
protocols := a.Protocols()
endpoints := a.Endpoints()
require.NotEmpty(t, protocols)
require.NotEmpty(t, endpoints)
for _, endpoint := range endpoints {
require.Contains(t, protocols, endpoint.Protocol)
}
})
}
func TestTransport_String(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
require.NotEmpty(t, a.String())
})
}
func TestConnection_Handshake(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, ba := dialAccept(t, a, b)
// The handshake should exchange the given NodeInfo and keys with the peer.
aKey := ed25519.GenPrivKey()
aInfo := p2p.NodeInfo{
NodeID: p2p.NodeIDFromPubKey(aKey.PubKey()),
ProtocolVersion: p2p.NewProtocolVersion(1, 2, 3),
ListenAddr: "listenaddr",
Network: "network",
Version: "1.2.3",
Channels: bytes.HexBytes([]byte{0xf0, 0x0f}),
Moniker: "moniker",
Other: p2p.NodeInfoOther{
TxIndex: "txindex",
RPCAddress: "rpc.domain.com",
},
}
bKey := ed25519.GenPrivKey()
bInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(bKey.PubKey())}
errCh := make(chan error, 1)
go func() {
// Must use assert due to goroutine.
peerInfo, peerKey, err := ba.Handshake(ctx, bInfo, bKey)
if err == nil {
assert.Equal(t, aInfo, peerInfo)
assert.Equal(t, aKey.PubKey(), peerKey)
}
errCh <- err
}()
peerInfo, peerKey, err := ab.Handshake(ctx, aInfo, aKey)
require.NoError(t, err)
require.Equal(t, bInfo, peerInfo)
require.Equal(t, bKey.PubKey(), peerKey)
require.NoError(t, <-errCh)
})
}
func TestConnection_HandshakeCancel(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
// Handshake should error on context cancellation.
ab, ba := dialAccept(t, a, b)
timeoutCtx, cancel := context.WithTimeout(ctx, 1*time.Minute)
cancel()
_, _, err := ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey())
require.Error(t, err)
require.Equal(t, context.Canceled, err)
_ = ab.Close()
_ = ba.Close()
// Handshake should error on context timeout.
ab, ba = dialAccept(t, a, b)
timeoutCtx, cancel = context.WithTimeout(ctx, 200*time.Millisecond)
defer cancel()
_, _, err = ab.Handshake(timeoutCtx, p2p.NodeInfo{}, ed25519.GenPrivKey())
require.Error(t, err)
require.Equal(t, context.DeadlineExceeded, err)
_ = ab.Close()
_ = ba.Close()
})
}
func TestConnection_FlushClose(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, _ := dialAcceptHandshake(t, a, b)
// FIXME: FlushClose should be removed (and replaced by separate Flush
// and Close calls if necessary). We can't reliably test it, so we just
// make sure it closes both ends and that it's idempotent.
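//
// If FlushClose is ever split into separate Flush and Close calls, the
// check below could look roughly like this sketch (Flush is not part of
// the current Connection interface, so its name is only an assumption):
//
//	require.NoError(t, ab.Flush())
//	require.NoError(t, ab.Close())
//	_, _, err := ab.ReceiveMessage()
//	require.Equal(t, io.EOF, err)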
err := ab.FlushClose()
require.NoError(t, err)
_, _, err = ab.ReceiveMessage()
require.Error(t, err)
require.Equal(t, io.EOF, err)
_, err = ab.SendMessage(chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
err = ab.FlushClose()
require.NoError(t, err)
})
}
func TestConnection_LocalRemoteEndpoint(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, ba := dialAcceptHandshake(t, a, b)
// Local and remote connection endpoints correspond to each other.
require.NotEmpty(t, ab.LocalEndpoint())
require.NotEmpty(t, ba.LocalEndpoint())
require.Equal(t, ab.LocalEndpoint(), ba.RemoteEndpoint())
require.Equal(t, ab.RemoteEndpoint(), ba.LocalEndpoint())
})
}
func TestConnection_SendReceive(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, ba := dialAcceptHandshake(t, a, b)
// Can send and receive from a to b.
ok, err := ab.SendMessage(chID, []byte("foo"))
require.NoError(t, err)
require.True(t, ok)
ch, msg, err := ba.ReceiveMessage()
require.NoError(t, err)
require.Equal(t, []byte("foo"), msg)
require.Equal(t, chID, ch)
// Can send and receive from b to a.
_, err = ba.SendMessage(chID, []byte("bar"))
require.NoError(t, err)
_, msg, err = ab.ReceiveMessage()
require.NoError(t, err)
require.Equal(t, []byte("bar"), msg)
// TrySendMessage also works.
ok, err = ba.TrySendMessage(chID, []byte("try"))
require.NoError(t, err)
require.True(t, ok)
ch, msg, err = ab.ReceiveMessage()
require.NoError(t, err)
require.Equal(t, []byte("try"), msg)
require.Equal(t, chID, ch)
// Connections should still be active after closing the transports.
err = a.Close()
require.NoError(t, err)
err = b.Close()
require.NoError(t, err)
_, err = ab.SendMessage(chID, []byte("still here"))
require.NoError(t, err)
ch, msg, err = ba.ReceiveMessage()
require.NoError(t, err)
require.Equal(t, chID, ch)
require.Equal(t, []byte("still here"), msg)
// Close one side of the connection. Both sides should then error
// with io.EOF when trying to send or receive.
err = ba.Close()
require.NoError(t, err)
_, _, err = ab.ReceiveMessage()
require.Error(t, err)
require.Equal(t, io.EOF, err)
_, err = ab.SendMessage(chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
_, _, err = ba.ReceiveMessage()
require.Error(t, err)
require.Equal(t, io.EOF, err)
_, err = ba.SendMessage(chID, []byte("closed"))
require.Error(t, err)
require.Equal(t, io.EOF, err)
})
}
func TestConnection_Status(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, _ := dialAcceptHandshake(t, a, b)
// FIXME: This isn't implemented in all transports, so for now we just
// check that it doesn't panic, which isn't really much of a test.
ab.Status()
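// Once Status is implemented across all transports, a slightly stronger
// check could look like this sketch (assuming ConnectionStatus exposes a
// Duration field; that field name is an assumption here):
//
//	status := ab.Status()
//	assert.GreaterOrEqual(t, status.Duration, time.Duration(0))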
})
}
func TestConnection_String(t *testing.T) {
withTransports(t, func(t *testing.T, makeTransport transportFactory) {
a := makeTransport(t)
b := makeTransport(t)
ab, _ := dialAccept(t, a, b)
require.NotEmpty(t, ab.String())
})
}
func TestEndpoint_PeerAddress(t *testing.T) {
var (
ip4 = []byte{1, 2, 3, 4}
ip4in6 = net.IPv4(1, 2, 3, 4)
ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
)
testcases := []struct {
endpoint p2p.Endpoint
expect p2p.PeerAddress
}{
// Valid endpoints.
{
p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"},
p2p.PeerAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
},
{
p2p.Endpoint{Protocol: "tcp", IP: ip4in6, Port: 8080, Path: "path"},
p2p.PeerAddress{Protocol: "tcp", Hostname: "1.2.3.4", Port: 8080, Path: "path"},
},
{
p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "path"},
p2p.PeerAddress{Protocol: "tcp", Hostname: "b10c::1", Port: 8080, Path: "path"},
},
{
p2p.Endpoint{Protocol: "memory", Path: "foo"},
p2p.PeerAddress{Protocol: "memory", Path: "foo"},
},
// Partial (invalid) endpoints.
{p2p.Endpoint{}, p2p.PeerAddress{}},
{p2p.Endpoint{Protocol: "tcp"}, p2p.PeerAddress{Protocol: "tcp"}},
{p2p.Endpoint{IP: net.IPv4(1, 2, 3, 4)}, p2p.PeerAddress{Hostname: "1.2.3.4"}},
{p2p.Endpoint{Port: 8080}, p2p.PeerAddress{}},
{p2p.Endpoint{Path: "path"}, p2p.PeerAddress{Path: "path"}},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.endpoint.String(), func(t *testing.T) {
// Without NodeID.
expect := tc.expect
require.Equal(t, expect, tc.endpoint.PeerAddress(""))
// With NodeID.
expect.NodeID = p2p.NodeID("b10c")
require.Equal(t, expect, tc.endpoint.PeerAddress(expect.NodeID))
})
}
}
func TestEndpoint_String(t *testing.T) {
var (
ip4 = []byte{1, 2, 3, 4}
ip4in6 = net.IPv4(1, 2, 3, 4)
ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
)
testcases := []struct {
endpoint p2p.Endpoint
expect string
}{
// Non-networked endpoints.
{p2p.Endpoint{Protocol: "memory", Path: "foo"}, "memory:foo"},
{p2p.Endpoint{Protocol: "memory", Path: "👋"}, "memory:👋"},
// IPv4 endpoints.
{p2p.Endpoint{Protocol: "tcp", IP: ip4}, "tcp://1.2.3.4"},
{p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, "tcp://1.2.3.4"},
{p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080}, "tcp://1.2.3.4:8080"},
{p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "/path"}, "tcp://1.2.3.4:8080/path"},
{p2p.Endpoint{Protocol: "tcp", IP: ip4, Path: "path/👋"}, "tcp://1.2.3.4/path/%F0%9F%91%8B"},
// IPv6 endpoints.
{p2p.Endpoint{Protocol: "tcp", IP: ip6}, "tcp://b10c::1"},
{p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080}, "tcp://[b10c::1]:8080"},
{p2p.Endpoint{Protocol: "tcp", IP: ip6, Port: 8080, Path: "/path"}, "tcp://[b10c::1]:8080/path"},
{p2p.Endpoint{Protocol: "tcp", IP: ip6, Path: "path/👋"}, "tcp://b10c::1/path/%F0%9F%91%8B"},
// Partial (invalid) endpoints.
{p2p.Endpoint{}, ""},
{p2p.Endpoint{Protocol: "tcp"}, "tcp:"},
{p2p.Endpoint{IP: []byte{1, 2, 3, 4}}, "1.2.3.4"},
{p2p.Endpoint{IP: []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}}, "b10c::1"},
{p2p.Endpoint{Port: 8080}, ""},
{p2p.Endpoint{Path: "foo"}, "/foo"},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.expect, func(t *testing.T) {
require.Equal(t, tc.expect, tc.endpoint.String())
})
}
}
func TestEndpoint_Validate(t *testing.T) {
var (
ip4 = []byte{1, 2, 3, 4}
ip4in6 = net.IPv4(1, 2, 3, 4)
ip6 = []byte{0xb1, 0x0c, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0x01}
)
testcases := []struct {
endpoint p2p.Endpoint
expectValid bool
}{
// Valid endpoints.
{p2p.Endpoint{Protocol: "tcp", IP: ip4}, true},
{p2p.Endpoint{Protocol: "tcp", IP: ip4in6}, true},
{p2p.Endpoint{Protocol: "tcp", IP: ip6}, true},
{p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8008}, true},
{p2p.Endpoint{Protocol: "tcp", IP: ip4, Port: 8080, Path: "path"}, true},
{p2p.Endpoint{Protocol: "memory", Path: "path"}, true},
// Invalid endpoints.
{p2p.Endpoint{}, false},
{p2p.Endpoint{IP: ip4}, false},
{p2p.Endpoint{Protocol: "tcp"}, false},
{p2p.Endpoint{Protocol: "tcp", IP: []byte{1, 2, 3}}, false},
{p2p.Endpoint{Protocol: "tcp", Port: 8080, Path: "path"}, false},
}
for _, tc := range testcases {
tc := tc
t.Run(tc.endpoint.String(), func(t *testing.T) {
err := tc.endpoint.Validate()
if tc.expectValid {
require.NoError(t, err)
} else {
require.Error(t, err)
}
})
}
}
// dialAccept is a helper that dials b from a and returns both sides of the
// connection.
func dialAccept(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
t.Helper()
endpoints := b.Endpoints()
require.NotEmpty(t, endpoints, "peer not listening on any endpoints")
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
acceptCh := make(chan p2p.Connection, 1)
errCh := make(chan error, 1)
go func() {
conn, err := b.Accept()
errCh <- err
acceptCh <- conn
}()
dialConn, err := a.Dial(ctx, endpoints[0])
require.NoError(t, err)
acceptConn := <-acceptCh
require.NoError(t, <-errCh)
t.Cleanup(func() {
_ = dialConn.Close()
_ = acceptConn.Close()
})
return dialConn, acceptConn
}
// dialAcceptHandshake is a helper that dials and handshakes b from a and
// returns both sides of the connection.
func dialAcceptHandshake(t *testing.T, a, b p2p.Transport) (p2p.Connection, p2p.Connection) {
t.Helper()
ab, ba := dialAccept(t, a, b)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
errCh := make(chan error, 1)
go func() {
privKey := ed25519.GenPrivKey()
nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())}
_, _, err := ba.Handshake(ctx, nodeInfo, privKey)
errCh <- err
}()
privKey := ed25519.GenPrivKey()
nodeInfo := p2p.NodeInfo{NodeID: p2p.NodeIDFromPubKey(privKey.PubKey())}
_, _, err := ab.Handshake(ctx, nodeInfo, privKey)
require.NoError(t, err)
timer := time.NewTimer(2 * time.Second)
defer timer.Stop()
select {
case err := <-errCh:
require.NoError(t, err)
case <-timer.C:
require.Fail(t, "handshake timed out")
}
return ab, ba
}
