|
|
- package p2p
-
- import (
- "context"
- "errors"
- "fmt"
- "io"
- "sync"
-
- "github.com/tendermint/tendermint/crypto"
- "github.com/tendermint/tendermint/crypto/ed25519"
- "github.com/tendermint/tendermint/libs/log"
- "github.com/tendermint/tendermint/p2p/conn"
- )
-
- const (
- MemoryProtocol Protocol = "memory"
- )
-
- // MemoryNetwork is an in-memory "network" that uses Go channels to communicate
- // between endpoints. Transport endpoints are created with CreateTransport. It
- // is primarily used for testing.
- type MemoryNetwork struct {
- logger log.Logger
-
- mtx sync.RWMutex
- transports map[NodeID]*MemoryTransport
- }
-
- // NewMemoryNetwork creates a new in-memory network.
- func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
- return &MemoryNetwork{
- logger: logger,
- transports: map[NodeID]*MemoryTransport{},
- }
- }
-
- // CreateTransport creates a new memory transport and endpoint for the given
- // NodeInfo and private key. Use GenerateTransport() to autogenerate a random
- // key and node info.
- //
- // The transport immediately begins listening on the endpoint "memory:<id>", and
- // can be accessed by other transports in the same memory network.
- func (n *MemoryNetwork) CreateTransport(
- nodeInfo NodeInfo,
- privKey crypto.PrivKey,
- ) (*MemoryTransport, error) {
- nodeID := nodeInfo.NodeID
- if nodeID == "" {
- return nil, errors.New("no node ID")
- }
- t := newMemoryTransport(n, nodeInfo, privKey)
-
- n.mtx.Lock()
- defer n.mtx.Unlock()
- if _, ok := n.transports[nodeID]; ok {
- return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
- }
- n.transports[nodeID] = t
- return t, nil
- }
-
- // GenerateTransport generates a new transport endpoint by generating a random
- // private key and node info. The endpoint address can be obtained via
- // Transport.Endpoints().
- func (n *MemoryNetwork) GenerateTransport() *MemoryTransport {
- privKey := ed25519.GenPrivKey()
- nodeID := NodeIDFromPubKey(privKey.PubKey())
- nodeInfo := NodeInfo{
- NodeID: nodeID,
- ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID),
- }
- t, err := n.CreateTransport(nodeInfo, privKey)
- if err != nil {
- // GenerateTransport is only used for testing, and the likelihood of
- // generating a duplicate node ID is very low, so we'll panic.
- panic(err)
- }
- return t
- }
-
- // GetTransport looks up a transport in the network, returning nil if not found.
- func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
- n.mtx.RLock()
- defer n.mtx.RUnlock()
- return n.transports[id]
- }
-
- // RemoveTransport removes a transport from the network and closes it.
- func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
- n.mtx.Lock()
- t, ok := n.transports[id]
- delete(n.transports, id)
- n.mtx.Unlock()
-
- if ok {
- // Close may recursively call RemoveTransport() again, but this is safe
- // because we've already removed the transport from the map above.
- return t.Close()
- }
- return nil
- }
-
- // MemoryTransport is an in-memory transport that's primarily meant for testing.
- // It communicates between endpoints using Go channels. To dial a different
- // endpoint, both endpoints/transports must be in the same MemoryNetwork.
- type MemoryTransport struct {
- network *MemoryNetwork
- nodeInfo NodeInfo
- privKey crypto.PrivKey
- logger log.Logger
-
- acceptCh chan *MemoryConnection
- closeCh chan struct{}
- closeOnce sync.Once
- }
-
- // newMemoryTransport creates a new in-memory transport in the given network.
- // Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
- // to create transports, this is for internal use by MemoryNetwork.
- func newMemoryTransport(
- network *MemoryNetwork,
- nodeInfo NodeInfo,
- privKey crypto.PrivKey,
- ) *MemoryTransport {
- return &MemoryTransport{
- network: network,
- nodeInfo: nodeInfo,
- privKey: privKey,
- logger: network.logger.With("local",
- fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.NodeID)),
-
- acceptCh: make(chan *MemoryConnection),
- closeCh: make(chan struct{}),
- }
- }
-
- // String displays the transport.
- //
- // FIXME: The Transport interface should either have Name() or embed
- // fmt.Stringer. This is necessary since we log the transport (to know which one
- // it is), and if it doesn't implement fmt.Stringer then it inspects all struct
- // contents via reflect, which triggers the race detector.
- func (t *MemoryTransport) String() string {
- return "memory"
- }
-
- // Accept implements Transport.
- func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
- select {
- case conn := <-t.acceptCh:
- t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
- return conn, nil
- case <-t.closeCh:
- return nil, ErrTransportClosed{}
- case <-ctx.Done():
- return nil, ctx.Err()
- }
- }
-
- // Dial implements Transport.
- func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
- if endpoint.Protocol != MemoryProtocol {
- return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
- }
- if endpoint.PeerID == "" {
- return nil, errors.New("no peer ID")
- }
- t.logger.Info("dialing peer", "remote", endpoint)
-
- peerTransport := t.network.GetTransport(endpoint.PeerID)
- if peerTransport == nil {
- return nil, fmt.Errorf("unknown peer %q", endpoint.PeerID)
- }
- inCh := make(chan memoryMessage, 1)
- outCh := make(chan memoryMessage, 1)
- closeCh := make(chan struct{})
- closeOnce := sync.Once{}
- closer := func() bool {
- closed := false
- closeOnce.Do(func() {
- close(closeCh)
- closed = true
- })
- return closed
- }
-
- outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer)
- inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer)
-
- select {
- case peerTransport.acceptCh <- inConn:
- return outConn, nil
- case <-peerTransport.closeCh:
- return nil, ErrTransportClosed{}
- case <-ctx.Done():
- return nil, ctx.Err()
- }
- }
-
- // DialAccept is a convenience function that dials a peer MemoryTransport and
- // returns both ends of the connection (A to B and B to A).
- func (t *MemoryTransport) DialAccept(
- ctx context.Context,
- peer *MemoryTransport,
- ) (Connection, Connection, error) {
- endpoints := peer.Endpoints()
- if len(endpoints) == 0 {
- return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.NodeID)
- }
-
- acceptCh := make(chan Connection, 1)
- errCh := make(chan error, 1)
- go func() {
- conn, err := peer.Accept(ctx)
- errCh <- err
- acceptCh <- conn
- }()
-
- outConn, err := t.Dial(ctx, endpoints[0])
- if err != nil {
- return nil, nil, err
- }
- if err = <-errCh; err != nil {
- return nil, nil, err
- }
- inConn := <-acceptCh
-
- return outConn, inConn, nil
- }
-
- // Close implements Transport.
- func (t *MemoryTransport) Close() error {
- err := t.network.RemoveTransport(t.nodeInfo.NodeID)
- t.closeOnce.Do(func() {
- close(t.closeCh)
- })
- t.logger.Info("stopped accepting connections")
- return err
- }
-
- // Endpoints implements Transport.
- func (t *MemoryTransport) Endpoints() []Endpoint {
- select {
- case <-t.closeCh:
- return []Endpoint{}
- default:
- return []Endpoint{{
- Protocol: MemoryProtocol,
- PeerID: t.nodeInfo.NodeID,
- }}
- }
- }
-
- // SetChannelDescriptors implements Transport.
- func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) {
- }
-
- // MemoryConnection is an in-memory connection between two transports (nodes).
- type MemoryConnection struct {
- logger log.Logger
- local *MemoryTransport
- remote *MemoryTransport
-
- receiveCh <-chan memoryMessage
- sendCh chan<- memoryMessage
- closeCh <-chan struct{}
- close func() bool
- }
-
// memoryMessage is the unit passed between the two ends of a MemoryConnection.
type memoryMessage struct {
	// channel is the channel ID the message was sent on.
	channel byte
	// message is the payload. It is passed by reference, not copied.
	message []byte
}
-
- // newMemoryConnection creates a new MemoryConnection. It takes all channels
- // (including the closeCh signal channel) on construction, such that they can be
- // shared between both ends of the connection.
- func newMemoryConnection(
- local *MemoryTransport,
- remote *MemoryTransport,
- receiveCh <-chan memoryMessage,
- sendCh chan<- memoryMessage,
- closeCh <-chan struct{},
- close func() bool,
- ) *MemoryConnection {
- c := &MemoryConnection{
- local: local,
- remote: remote,
- receiveCh: receiveCh,
- sendCh: sendCh,
- closeCh: closeCh,
- close: close,
- }
- c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
- return c
- }
-
- // ReceiveMessage implements Connection.
- func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
- // check close first, since channels are buffered
- select {
- case <-c.closeCh:
- return 0, nil, io.EOF
- default:
- }
-
- select {
- case msg := <-c.receiveCh:
- c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
- return msg.channel, msg.message, nil
- case <-c.closeCh:
- return 0, nil, io.EOF
- }
- }
-
- // SendMessage implements Connection.
- func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
- // check close first, since channels are buffered
- select {
- case <-c.closeCh:
- return false, io.EOF
- default:
- }
-
- select {
- case c.sendCh <- memoryMessage{channel: chID, message: msg}:
- c.logger.Debug("sent message", "channel", chID, "message", msg)
- return true, nil
- case <-c.closeCh:
- return false, io.EOF
- }
- }
-
- // TrySendMessage implements Connection.
- func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
- // check close first, since channels are buffered
- select {
- case <-c.closeCh:
- return false, io.EOF
- default:
- }
-
- select {
- case c.sendCh <- memoryMessage{channel: chID, message: msg}:
- c.logger.Debug("sent message", "channel", chID, "message", msg)
- return true, nil
- case <-c.closeCh:
- return false, io.EOF
- default:
- return false, nil
- }
- }
-
- // Close closes the connection.
- func (c *MemoryConnection) Close() error {
- if c.close() {
- c.logger.Info("closed connection")
- }
- return nil
- }
-
- // FlushClose flushes all pending sends and then closes the connection.
- func (c *MemoryConnection) FlushClose() error {
- return c.Close()
- }
-
- // LocalEndpoint returns the local endpoint for the connection.
- func (c *MemoryConnection) LocalEndpoint() Endpoint {
- return Endpoint{
- PeerID: c.local.nodeInfo.NodeID,
- Protocol: MemoryProtocol,
- }
- }
-
- // RemoteEndpoint returns the remote endpoint for the connection.
- func (c *MemoryConnection) RemoteEndpoint() Endpoint {
- return Endpoint{
- PeerID: c.remote.nodeInfo.NodeID,
- Protocol: MemoryProtocol,
- }
- }
-
- // PubKey returns the remote peer's public key.
- func (c *MemoryConnection) PubKey() crypto.PubKey {
- return c.remote.privKey.PubKey()
- }
-
- // NodeInfo returns the remote peer's node info.
- func (c *MemoryConnection) NodeInfo() NodeInfo {
- return c.remote.nodeInfo
- }
-
- // Status returns the current connection status.
- func (c *MemoryConnection) Status() conn.ConnectionStatus {
- return conn.ConnectionStatus{}
- }
|