package p2p import ( "context" "errors" "fmt" "io" "sync" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/libs/log" tmsync "github.com/tendermint/tendermint/libs/sync" "github.com/tendermint/tendermint/p2p/conn" ) const ( MemoryProtocol Protocol = "memory" // bufferSize is the channel buffer size of MemoryConnection. bufferSize = 1 ) // MemoryNetwork is an in-memory "network" that uses buffered Go channels to // communicate between endpoints. It is primarily meant for testing. // // Network endpoints are allocated via CreateTransport(), which takes a node ID, // and the endpoint is then immediately accessible via the URL "memory:". type MemoryNetwork struct { logger log.Logger mtx sync.RWMutex transports map[NodeID]*MemoryTransport } // NewMemoryNetwork creates a new in-memory network. func NewMemoryNetwork(logger log.Logger) *MemoryNetwork { return &MemoryNetwork{ logger: logger, transports: map[NodeID]*MemoryTransport{}, } } // CreateTransport creates a new memory transport endpoint with the given node // ID and immediately begins listening on the address "memory:". It panics // if the node ID is already in use (which is fine, since this is for tests). func (n *MemoryNetwork) CreateTransport(nodeID NodeID) *MemoryTransport { t := newMemoryTransport(n, nodeID) n.mtx.Lock() defer n.mtx.Unlock() if _, ok := n.transports[nodeID]; ok { panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID)) } n.transports[nodeID] = t return t } // GetTransport looks up a transport in the network, returning nil if not found. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport { n.mtx.RLock() defer n.mtx.RUnlock() return n.transports[id] } // RemoveTransport removes a transport from the network and closes it. func (n *MemoryNetwork) RemoveTransport(id NodeID) { n.mtx.Lock() t, ok := n.transports[id] delete(n.transports, id) n.mtx.Unlock() if ok { // Close may recursively call RemoveTransport() again, but this is safe // because we've already removed the transport from the map above. if err := t.Close(); err != nil { n.logger.Error("failed to close memory transport", "id", id, "err", err) } } } // Size returns the number of transports in the network. func (n *MemoryNetwork) Size() int { return len(n.transports) } // MemoryTransport is an in-memory transport that uses buffered Go channels to // communicate between endpoints. It is primarily meant for testing. // // New transports are allocated with MemoryNetwork.CreateTransport(). To contact // a different endpoint, both transports must be in the same MemoryNetwork. type MemoryTransport struct { logger log.Logger network *MemoryNetwork nodeID NodeID acceptCh chan *MemoryConnection closeCh chan struct{} closeOnce sync.Once } // newMemoryTransport creates a new MemoryTransport. This is for internal use by // MemoryNetwork, use MemoryNetwork.CreateTransport() instead. func newMemoryTransport(network *MemoryNetwork, nodeID NodeID) *MemoryTransport { return &MemoryTransport{ logger: network.logger.With("local", nodeID), network: network, nodeID: nodeID, acceptCh: make(chan *MemoryConnection), closeCh: make(chan struct{}), } } // String implements Transport. func (t *MemoryTransport) String() string { return string(MemoryProtocol) } // Protocols implements Transport. func (t *MemoryTransport) Protocols() []Protocol { return []Protocol{MemoryProtocol} } // Endpoints implements Transport. func (t *MemoryTransport) Endpoints() []Endpoint { select { case <-t.closeCh: return []Endpoint{} default: return []Endpoint{{ Protocol: MemoryProtocol, Path: string(t.nodeID), }} } } // Accept implements Transport. func (t *MemoryTransport) Accept() (Connection, error) { select { case conn := <-t.acceptCh: t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path) return conn, nil case <-t.closeCh: return nil, io.EOF } } // Dial implements Transport. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) { if endpoint.Protocol != MemoryProtocol { return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol) } if endpoint.Path == "" { return nil, errors.New("no path") } nodeID, err := NewNodeID(endpoint.Path) if err != nil { return nil, err } t.logger.Info("dialing peer", "remote", nodeID) peer := t.network.GetTransport(nodeID) if peer == nil { return nil, fmt.Errorf("unknown peer %q", nodeID) } inCh := make(chan memoryMessage, bufferSize) outCh := make(chan memoryMessage, bufferSize) closer := tmsync.NewCloser() outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer) inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer) select { case peer.acceptCh <- inConn: return outConn, nil case <-peer.closeCh: return nil, io.EOF case <-ctx.Done(): return nil, ctx.Err() } } // Close implements Transport. func (t *MemoryTransport) Close() error { t.network.RemoveTransport(t.nodeID) t.closeOnce.Do(func() { close(t.closeCh) t.logger.Info("closed transport") }) return nil } // MemoryConnection is an in-memory connection between two transport endpoints. type MemoryConnection struct { logger log.Logger localID NodeID remoteID NodeID receiveCh <-chan memoryMessage sendCh chan<- memoryMessage closer *tmsync.Closer } // memoryMessage is passed internally, containing either a message or handshake. type memoryMessage struct { channelID ChannelID message []byte // For handshakes. nodeInfo *NodeInfo pubKey crypto.PubKey } // newMemoryConnection creates a new MemoryConnection. func newMemoryConnection( logger log.Logger, localID NodeID, remoteID NodeID, receiveCh <-chan memoryMessage, sendCh chan<- memoryMessage, closer *tmsync.Closer, ) *MemoryConnection { return &MemoryConnection{ logger: logger.With("remote", remoteID), localID: localID, remoteID: remoteID, receiveCh: receiveCh, sendCh: sendCh, closer: closer, } } // String implements Connection. func (c *MemoryConnection) String() string { return c.RemoteEndpoint().String() } // LocalEndpoint implements Connection. func (c *MemoryConnection) LocalEndpoint() Endpoint { return Endpoint{ Protocol: MemoryProtocol, Path: string(c.localID), } } // RemoteEndpoint implements Connection. func (c *MemoryConnection) RemoteEndpoint() Endpoint { return Endpoint{ Protocol: MemoryProtocol, Path: string(c.remoteID), } } // Status implements Connection. func (c *MemoryConnection) Status() conn.ConnectionStatus { return conn.ConnectionStatus{} } // Handshake implements Connection. func (c *MemoryConnection) Handshake( ctx context.Context, nodeInfo NodeInfo, privKey crypto.PrivKey, ) (NodeInfo, crypto.PubKey, error) { select { case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}: c.logger.Debug("sent handshake", "nodeInfo", nodeInfo) case <-c.closer.Done(): return NodeInfo{}, nil, io.EOF case <-ctx.Done(): return NodeInfo{}, nil, ctx.Err() } select { case msg := <-c.receiveCh: if msg.nodeInfo == nil { return NodeInfo{}, nil, errors.New("no NodeInfo in handshake") } c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo) return *msg.nodeInfo, msg.pubKey, nil case <-c.closer.Done(): return NodeInfo{}, nil, io.EOF case <-ctx.Done(): return NodeInfo{}, nil, ctx.Err() } } // ReceiveMessage implements Connection. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) { // Check close first, since channels are buffered. Otherwise, below select // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return 0, nil, io.EOF default: } select { case msg := <-c.receiveCh: c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message) return msg.channelID, msg.message, nil case <-c.closer.Done(): return 0, nil, io.EOF } } // SendMessage implements Connection. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) { // Check close first, since channels are buffered. Otherwise, below select // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return false, io.EOF default: } select { case c.sendCh <- memoryMessage{channelID: chID, message: msg}: c.logger.Debug("sent message", "chID", chID, "msg", msg) return true, nil case <-c.closer.Done(): return false, io.EOF } } // TrySendMessage implements Connection. func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) { // Check close first, since channels are buffered. Otherwise, below select // may non-deterministically return non-error even when closed. select { case <-c.closer.Done(): return false, io.EOF default: } select { case c.sendCh <- memoryMessage{channelID: chID, message: msg}: c.logger.Debug("sent message", "chID", chID, "msg", msg) return true, nil case <-c.closer.Done(): return false, io.EOF default: return false, nil } } // Close implements Connection. func (c *MemoryConnection) Close() error { select { case <-c.closer.Done(): return nil default: c.closer.Close() c.logger.Info("closed connection") } return nil } // FlushClose implements Connection. func (c *MemoryConnection) FlushClose() error { return c.Close() }