You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

345 lines
9.0 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. tmsync "github.com/tendermint/tendermint/internal/libs/sync"
  11. "github.com/tendermint/tendermint/libs/log"
  12. "github.com/tendermint/tendermint/types"
  13. )
  14. const (
  15. MemoryProtocol Protocol = "memory"
  16. )
  17. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  18. // communicate between endpoints. It is primarily meant for testing.
  19. //
  20. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  21. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  22. type MemoryNetwork struct {
  23. logger log.Logger
  24. mtx sync.RWMutex
  25. transports map[types.NodeID]*MemoryTransport
  26. bufferSize int
  27. }
  28. // NewMemoryNetwork creates a new in-memory network.
  29. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  30. return &MemoryNetwork{
  31. bufferSize: bufferSize,
  32. logger: logger,
  33. transports: map[types.NodeID]*MemoryTransport{},
  34. }
  35. }
  36. // CreateTransport creates a new memory transport endpoint with the given node
  37. // ID and immediately begins listening on the address "memory:<id>". It panics
  38. // if the node ID is already in use (which is fine, since this is for tests).
  39. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  40. t := newMemoryTransport(n, nodeID)
  41. n.mtx.Lock()
  42. defer n.mtx.Unlock()
  43. if _, ok := n.transports[nodeID]; ok {
  44. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  45. }
  46. n.transports[nodeID] = t
  47. return t
  48. }
  49. // GetTransport looks up a transport in the network, returning nil if not found.
  50. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  51. n.mtx.RLock()
  52. defer n.mtx.RUnlock()
  53. return n.transports[id]
  54. }
  55. // RemoveTransport removes a transport from the network and closes it.
  56. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  57. n.mtx.Lock()
  58. t, ok := n.transports[id]
  59. delete(n.transports, id)
  60. n.mtx.Unlock()
  61. if ok {
  62. // Close may recursively call RemoveTransport() again, but this is safe
  63. // because we've already removed the transport from the map above.
  64. if err := t.Close(); err != nil {
  65. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  66. }
  67. }
  68. }
  69. // Size returns the number of transports in the network.
  70. func (n *MemoryNetwork) Size() int {
  71. return len(n.transports)
  72. }
  73. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  74. // communicate between endpoints. It is primarily meant for testing.
  75. //
  76. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  77. // a different endpoint, both transports must be in the same MemoryNetwork.
  78. type MemoryTransport struct {
  79. logger log.Logger
  80. network *MemoryNetwork
  81. nodeID types.NodeID
  82. bufferSize int
  83. acceptCh chan *MemoryConnection
  84. closeCh chan struct{}
  85. closeOnce sync.Once
  86. }
  87. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  88. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  89. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  90. return &MemoryTransport{
  91. logger: network.logger.With("local", nodeID),
  92. network: network,
  93. nodeID: nodeID,
  94. bufferSize: network.bufferSize,
  95. acceptCh: make(chan *MemoryConnection),
  96. closeCh: make(chan struct{}),
  97. }
  98. }
  99. // String implements Transport.
  100. func (t *MemoryTransport) String() string {
  101. return string(MemoryProtocol)
  102. }
  103. func (*MemoryTransport) Listen(Endpoint) error { return nil }
  104. func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {}
  105. // Protocols implements Transport.
  106. func (t *MemoryTransport) Protocols() []Protocol {
  107. return []Protocol{MemoryProtocol}
  108. }
  109. // Endpoints implements Transport.
  110. func (t *MemoryTransport) Endpoints() []Endpoint {
  111. select {
  112. case <-t.closeCh:
  113. return []Endpoint{}
  114. default:
  115. return []Endpoint{{
  116. Protocol: MemoryProtocol,
  117. Path: string(t.nodeID),
  118. // An arbitrary IP and port is used in order for the pex
  119. // reactor to be able to send addresses to one another.
  120. IP: net.IPv4zero,
  121. Port: 0,
  122. }}
  123. }
  124. }
  125. // Accept implements Transport.
  126. func (t *MemoryTransport) Accept() (Connection, error) {
  127. select {
  128. case conn := <-t.acceptCh:
  129. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  130. return conn, nil
  131. case <-t.closeCh:
  132. return nil, io.EOF
  133. }
  134. }
  135. // Dial implements Transport.
  136. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  137. if endpoint.Protocol != MemoryProtocol {
  138. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  139. }
  140. if endpoint.Path == "" {
  141. return nil, errors.New("no path")
  142. }
  143. if err := endpoint.Validate(); err != nil {
  144. return nil, err
  145. }
  146. nodeID, err := types.NewNodeID(endpoint.Path)
  147. if err != nil {
  148. return nil, err
  149. }
  150. t.logger.Info("dialing peer", "remote", nodeID)
  151. peer := t.network.GetTransport(nodeID)
  152. if peer == nil {
  153. return nil, fmt.Errorf("unknown peer %q", nodeID)
  154. }
  155. inCh := make(chan memoryMessage, t.bufferSize)
  156. outCh := make(chan memoryMessage, t.bufferSize)
  157. closer := tmsync.NewCloser()
  158. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  159. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  160. select {
  161. case peer.acceptCh <- inConn:
  162. return outConn, nil
  163. case <-peer.closeCh:
  164. return nil, io.EOF
  165. case <-ctx.Done():
  166. return nil, ctx.Err()
  167. }
  168. }
  169. // Close implements Transport.
  170. func (t *MemoryTransport) Close() error {
  171. t.network.RemoveTransport(t.nodeID)
  172. t.closeOnce.Do(func() {
  173. close(t.closeCh)
  174. t.logger.Info("closed transport")
  175. })
  176. return nil
  177. }
  178. // MemoryConnection is an in-memory connection between two transport endpoints.
  179. type MemoryConnection struct {
  180. logger log.Logger
  181. localID types.NodeID
  182. remoteID types.NodeID
  183. receiveCh <-chan memoryMessage
  184. sendCh chan<- memoryMessage
  185. closer *tmsync.Closer
  186. }
  187. // memoryMessage is passed internally, containing either a message or handshake.
  188. type memoryMessage struct {
  189. channelID ChannelID
  190. message []byte
  191. // For handshakes.
  192. nodeInfo *types.NodeInfo
  193. pubKey crypto.PubKey
  194. }
  195. // newMemoryConnection creates a new MemoryConnection.
  196. func newMemoryConnection(
  197. logger log.Logger,
  198. localID types.NodeID,
  199. remoteID types.NodeID,
  200. receiveCh <-chan memoryMessage,
  201. sendCh chan<- memoryMessage,
  202. closer *tmsync.Closer,
  203. ) *MemoryConnection {
  204. return &MemoryConnection{
  205. logger: logger.With("remote", remoteID),
  206. localID: localID,
  207. remoteID: remoteID,
  208. receiveCh: receiveCh,
  209. sendCh: sendCh,
  210. closer: closer,
  211. }
  212. }
  213. // String implements Connection.
  214. func (c *MemoryConnection) String() string {
  215. return c.RemoteEndpoint().String()
  216. }
  217. // LocalEndpoint implements Connection.
  218. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  219. return Endpoint{
  220. Protocol: MemoryProtocol,
  221. Path: string(c.localID),
  222. }
  223. }
  224. // RemoteEndpoint implements Connection.
  225. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  226. return Endpoint{
  227. Protocol: MemoryProtocol,
  228. Path: string(c.remoteID),
  229. }
  230. }
  231. // Handshake implements Connection.
  232. func (c *MemoryConnection) Handshake(
  233. ctx context.Context,
  234. nodeInfo types.NodeInfo,
  235. privKey crypto.PrivKey,
  236. ) (types.NodeInfo, crypto.PubKey, error) {
  237. select {
  238. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  239. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  240. case <-c.closer.Done():
  241. return types.NodeInfo{}, nil, io.EOF
  242. case <-ctx.Done():
  243. return types.NodeInfo{}, nil, ctx.Err()
  244. }
  245. select {
  246. case msg := <-c.receiveCh:
  247. if msg.nodeInfo == nil {
  248. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  249. }
  250. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  251. return *msg.nodeInfo, msg.pubKey, nil
  252. case <-c.closer.Done():
  253. return types.NodeInfo{}, nil, io.EOF
  254. case <-ctx.Done():
  255. return types.NodeInfo{}, nil, ctx.Err()
  256. }
  257. }
  258. // ReceiveMessage implements Connection.
  259. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
  260. // Check close first, since channels are buffered. Otherwise, below select
  261. // may non-deterministically return non-error even when closed.
  262. select {
  263. case <-c.closer.Done():
  264. return 0, nil, io.EOF
  265. default:
  266. }
  267. select {
  268. case msg := <-c.receiveCh:
  269. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  270. return msg.channelID, msg.message, nil
  271. case <-c.closer.Done():
  272. return 0, nil, io.EOF
  273. }
  274. }
  275. // SendMessage implements Connection.
  276. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error {
  277. // Check close first, since channels are buffered. Otherwise, below select
  278. // may non-deterministically return non-error even when closed.
  279. select {
  280. case <-c.closer.Done():
  281. return io.EOF
  282. default:
  283. }
  284. select {
  285. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  286. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  287. return nil
  288. case <-c.closer.Done():
  289. return io.EOF
  290. }
  291. }
  292. // Close implements Connection.
  293. func (c *MemoryConnection) Close() error {
  294. select {
  295. case <-c.closer.Done():
  296. return nil
  297. default:
  298. c.closer.Close()
  299. c.logger.Info("closed connection")
  300. }
  301. return nil
  302. }