You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

343 lines
8.9 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. tmsync "github.com/tendermint/tendermint/internal/libs/sync"
  11. "github.com/tendermint/tendermint/libs/log"
  12. "github.com/tendermint/tendermint/types"
  13. )
  14. const (
  15. MemoryProtocol Protocol = "memory"
  16. )
  17. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  18. // communicate between endpoints. It is primarily meant for testing.
  19. //
  20. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  21. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  22. type MemoryNetwork struct {
  23. logger log.Logger
  24. mtx sync.RWMutex
  25. transports map[types.NodeID]*MemoryTransport
  26. bufferSize int
  27. }
  28. // NewMemoryNetwork creates a new in-memory network.
  29. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  30. return &MemoryNetwork{
  31. bufferSize: bufferSize,
  32. logger: logger,
  33. transports: map[types.NodeID]*MemoryTransport{},
  34. }
  35. }
  36. // CreateTransport creates a new memory transport endpoint with the given node
  37. // ID and immediately begins listening on the address "memory:<id>". It panics
  38. // if the node ID is already in use (which is fine, since this is for tests).
  39. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  40. t := newMemoryTransport(n, nodeID)
  41. n.mtx.Lock()
  42. defer n.mtx.Unlock()
  43. if _, ok := n.transports[nodeID]; ok {
  44. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  45. }
  46. n.transports[nodeID] = t
  47. return t
  48. }
  49. // GetTransport looks up a transport in the network, returning nil if not found.
  50. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  51. n.mtx.RLock()
  52. defer n.mtx.RUnlock()
  53. return n.transports[id]
  54. }
  55. // RemoveTransport removes a transport from the network and closes it.
  56. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  57. n.mtx.Lock()
  58. t, ok := n.transports[id]
  59. delete(n.transports, id)
  60. n.mtx.Unlock()
  61. if ok {
  62. // Close may recursively call RemoveTransport() again, but this is safe
  63. // because we've already removed the transport from the map above.
  64. if err := t.Close(); err != nil {
  65. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  66. }
  67. }
  68. }
  69. // Size returns the number of transports in the network.
  70. func (n *MemoryNetwork) Size() int {
  71. return len(n.transports)
  72. }
  73. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  74. // communicate between endpoints. It is primarily meant for testing.
  75. //
  76. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  77. // a different endpoint, both transports must be in the same MemoryNetwork.
  78. type MemoryTransport struct {
  79. logger log.Logger
  80. network *MemoryNetwork
  81. nodeID types.NodeID
  82. bufferSize int
  83. acceptCh chan *MemoryConnection
  84. closeCh chan struct{}
  85. closeOnce sync.Once
  86. }
  87. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  88. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  89. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  90. return &MemoryTransport{
  91. logger: network.logger.With("local", nodeID),
  92. network: network,
  93. nodeID: nodeID,
  94. bufferSize: network.bufferSize,
  95. acceptCh: make(chan *MemoryConnection),
  96. closeCh: make(chan struct{}),
  97. }
  98. }
  99. // String implements Transport.
  100. func (t *MemoryTransport) String() string {
  101. return string(MemoryProtocol)
  102. }
  103. func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {}
  104. // Protocols implements Transport.
  105. func (t *MemoryTransport) Protocols() []Protocol {
  106. return []Protocol{MemoryProtocol}
  107. }
  108. // Endpoints implements Transport.
  109. func (t *MemoryTransport) Endpoints() []Endpoint {
  110. select {
  111. case <-t.closeCh:
  112. return []Endpoint{}
  113. default:
  114. return []Endpoint{{
  115. Protocol: MemoryProtocol,
  116. Path: string(t.nodeID),
  117. // An arbitrary IP and port is used in order for the pex
  118. // reactor to be able to send addresses to one another.
  119. IP: net.IPv4zero,
  120. Port: 0,
  121. }}
  122. }
  123. }
  124. // Accept implements Transport.
  125. func (t *MemoryTransport) Accept() (Connection, error) {
  126. select {
  127. case conn := <-t.acceptCh:
  128. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  129. return conn, nil
  130. case <-t.closeCh:
  131. return nil, io.EOF
  132. }
  133. }
  134. // Dial implements Transport.
  135. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  136. if endpoint.Protocol != MemoryProtocol {
  137. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  138. }
  139. if endpoint.Path == "" {
  140. return nil, errors.New("no path")
  141. }
  142. if err := endpoint.Validate(); err != nil {
  143. return nil, err
  144. }
  145. nodeID, err := types.NewNodeID(endpoint.Path)
  146. if err != nil {
  147. return nil, err
  148. }
  149. t.logger.Info("dialing peer", "remote", nodeID)
  150. peer := t.network.GetTransport(nodeID)
  151. if peer == nil {
  152. return nil, fmt.Errorf("unknown peer %q", nodeID)
  153. }
  154. inCh := make(chan memoryMessage, t.bufferSize)
  155. outCh := make(chan memoryMessage, t.bufferSize)
  156. closer := tmsync.NewCloser()
  157. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  158. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  159. select {
  160. case peer.acceptCh <- inConn:
  161. return outConn, nil
  162. case <-peer.closeCh:
  163. return nil, io.EOF
  164. case <-ctx.Done():
  165. return nil, ctx.Err()
  166. }
  167. }
  168. // Close implements Transport.
  169. func (t *MemoryTransport) Close() error {
  170. t.network.RemoveTransport(t.nodeID)
  171. t.closeOnce.Do(func() {
  172. close(t.closeCh)
  173. t.logger.Info("closed transport")
  174. })
  175. return nil
  176. }
  177. // MemoryConnection is an in-memory connection between two transport endpoints.
  178. type MemoryConnection struct {
  179. logger log.Logger
  180. localID types.NodeID
  181. remoteID types.NodeID
  182. receiveCh <-chan memoryMessage
  183. sendCh chan<- memoryMessage
  184. closer *tmsync.Closer
  185. }
  186. // memoryMessage is passed internally, containing either a message or handshake.
  187. type memoryMessage struct {
  188. channelID ChannelID
  189. message []byte
  190. // For handshakes.
  191. nodeInfo *types.NodeInfo
  192. pubKey crypto.PubKey
  193. }
  194. // newMemoryConnection creates a new MemoryConnection.
  195. func newMemoryConnection(
  196. logger log.Logger,
  197. localID types.NodeID,
  198. remoteID types.NodeID,
  199. receiveCh <-chan memoryMessage,
  200. sendCh chan<- memoryMessage,
  201. closer *tmsync.Closer,
  202. ) *MemoryConnection {
  203. return &MemoryConnection{
  204. logger: logger.With("remote", remoteID),
  205. localID: localID,
  206. remoteID: remoteID,
  207. receiveCh: receiveCh,
  208. sendCh: sendCh,
  209. closer: closer,
  210. }
  211. }
  212. // String implements Connection.
  213. func (c *MemoryConnection) String() string {
  214. return c.RemoteEndpoint().String()
  215. }
  216. // LocalEndpoint implements Connection.
  217. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  218. return Endpoint{
  219. Protocol: MemoryProtocol,
  220. Path: string(c.localID),
  221. }
  222. }
  223. // RemoteEndpoint implements Connection.
  224. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  225. return Endpoint{
  226. Protocol: MemoryProtocol,
  227. Path: string(c.remoteID),
  228. }
  229. }
  230. // Handshake implements Connection.
  231. func (c *MemoryConnection) Handshake(
  232. ctx context.Context,
  233. nodeInfo types.NodeInfo,
  234. privKey crypto.PrivKey,
  235. ) (types.NodeInfo, crypto.PubKey, error) {
  236. select {
  237. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  238. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  239. case <-c.closer.Done():
  240. return types.NodeInfo{}, nil, io.EOF
  241. case <-ctx.Done():
  242. return types.NodeInfo{}, nil, ctx.Err()
  243. }
  244. select {
  245. case msg := <-c.receiveCh:
  246. if msg.nodeInfo == nil {
  247. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  248. }
  249. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  250. return *msg.nodeInfo, msg.pubKey, nil
  251. case <-c.closer.Done():
  252. return types.NodeInfo{}, nil, io.EOF
  253. case <-ctx.Done():
  254. return types.NodeInfo{}, nil, ctx.Err()
  255. }
  256. }
  257. // ReceiveMessage implements Connection.
  258. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
  259. // Check close first, since channels are buffered. Otherwise, below select
  260. // may non-deterministically return non-error even when closed.
  261. select {
  262. case <-c.closer.Done():
  263. return 0, nil, io.EOF
  264. default:
  265. }
  266. select {
  267. case msg := <-c.receiveCh:
  268. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  269. return msg.channelID, msg.message, nil
  270. case <-c.closer.Done():
  271. return 0, nil, io.EOF
  272. }
  273. }
  274. // SendMessage implements Connection.
  275. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) error {
  276. // Check close first, since channels are buffered. Otherwise, below select
  277. // may non-deterministically return non-error even when closed.
  278. select {
  279. case <-c.closer.Done():
  280. return io.EOF
  281. default:
  282. }
  283. select {
  284. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  285. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  286. return nil
  287. case <-c.closer.Done():
  288. return io.EOF
  289. }
  290. }
  291. // Close implements Connection.
  292. func (c *MemoryConnection) Close() error {
  293. select {
  294. case <-c.closer.Done():
  295. return nil
  296. default:
  297. c.closer.Close()
  298. c.logger.Info("closed connection")
  299. }
  300. return nil
  301. }