You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

341 lines
9.0 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. tmsync "github.com/tendermint/tendermint/internal/libs/sync"
  11. "github.com/tendermint/tendermint/libs/log"
  12. "github.com/tendermint/tendermint/types"
  13. )
  14. const (
  15. MemoryProtocol Protocol = "memory"
  16. )
  17. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  18. // communicate between endpoints. It is primarily meant for testing.
  19. //
  20. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  21. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  22. type MemoryNetwork struct {
  23. logger log.Logger
  24. mtx sync.RWMutex
  25. transports map[types.NodeID]*MemoryTransport
  26. bufferSize int
  27. }
  28. // NewMemoryNetwork creates a new in-memory network.
  29. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  30. return &MemoryNetwork{
  31. bufferSize: bufferSize,
  32. logger: logger,
  33. transports: map[types.NodeID]*MemoryTransport{},
  34. }
  35. }
  36. // CreateTransport creates a new memory transport endpoint with the given node
  37. // ID and immediately begins listening on the address "memory:<id>". It panics
  38. // if the node ID is already in use (which is fine, since this is for tests).
  39. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  40. t := newMemoryTransport(n, nodeID)
  41. n.mtx.Lock()
  42. defer n.mtx.Unlock()
  43. if _, ok := n.transports[nodeID]; ok {
  44. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  45. }
  46. n.transports[nodeID] = t
  47. return t
  48. }
  49. // GetTransport looks up a transport in the network, returning nil if not found.
  50. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  51. n.mtx.RLock()
  52. defer n.mtx.RUnlock()
  53. return n.transports[id]
  54. }
  55. // RemoveTransport removes a transport from the network and closes it.
  56. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  57. n.mtx.Lock()
  58. t, ok := n.transports[id]
  59. delete(n.transports, id)
  60. n.mtx.Unlock()
  61. if ok {
  62. // Close may recursively call RemoveTransport() again, but this is safe
  63. // because we've already removed the transport from the map above.
  64. if err := t.Close(); err != nil {
  65. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  66. }
  67. }
  68. }
  69. // Size returns the number of transports in the network.
  70. func (n *MemoryNetwork) Size() int {
  71. return len(n.transports)
  72. }
  73. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  74. // communicate between endpoints. It is primarily meant for testing.
  75. //
  76. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  77. // a different endpoint, both transports must be in the same MemoryNetwork.
  78. type MemoryTransport struct {
  79. logger log.Logger
  80. network *MemoryNetwork
  81. nodeID types.NodeID
  82. bufferSize int
  83. acceptCh chan *MemoryConnection
  84. }
  85. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  86. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  87. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  88. return &MemoryTransport{
  89. logger: network.logger.With("local", nodeID),
  90. network: network,
  91. nodeID: nodeID,
  92. bufferSize: network.bufferSize,
  93. acceptCh: make(chan *MemoryConnection),
  94. }
  95. }
  96. // String implements Transport.
  97. func (t *MemoryTransport) String() string {
  98. return string(MemoryProtocol)
  99. }
  100. func (*MemoryTransport) Listen(Endpoint) error { return nil }
  101. func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {}
  102. // Protocols implements Transport.
  103. func (t *MemoryTransport) Protocols() []Protocol {
  104. return []Protocol{MemoryProtocol}
  105. }
  106. // Endpoints implements Transport.
  107. func (t *MemoryTransport) Endpoints() []Endpoint {
  108. if n := t.network.GetTransport(t.nodeID); n == nil {
  109. return []Endpoint{}
  110. }
  111. return []Endpoint{{
  112. Protocol: MemoryProtocol,
  113. Path: string(t.nodeID),
  114. // An arbitrary IP and port is used in order for the pex
  115. // reactor to be able to send addresses to one another.
  116. IP: net.IPv4zero,
  117. Port: 0,
  118. }}
  119. }
  120. // Accept implements Transport.
  121. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  122. select {
  123. case conn := <-t.acceptCh:
  124. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  125. return conn, nil
  126. case <-ctx.Done():
  127. return nil, io.EOF
  128. }
  129. }
  130. // Dial implements Transport.
  131. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  132. if endpoint.Protocol != MemoryProtocol {
  133. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  134. }
  135. if endpoint.Path == "" {
  136. return nil, errors.New("no path")
  137. }
  138. if err := endpoint.Validate(); err != nil {
  139. return nil, err
  140. }
  141. nodeID, err := types.NewNodeID(endpoint.Path)
  142. if err != nil {
  143. return nil, err
  144. }
  145. t.logger.Info("dialing peer", "remote", nodeID)
  146. peer := t.network.GetTransport(nodeID)
  147. if peer == nil {
  148. return nil, fmt.Errorf("unknown peer %q", nodeID)
  149. }
  150. inCh := make(chan memoryMessage, t.bufferSize)
  151. outCh := make(chan memoryMessage, t.bufferSize)
  152. closer := tmsync.NewCloser()
  153. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  154. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  155. select {
  156. case peer.acceptCh <- inConn:
  157. return outConn, nil
  158. case <-ctx.Done():
  159. return nil, io.EOF
  160. }
  161. }
  162. // Close implements Transport.
  163. func (t *MemoryTransport) Close() error {
  164. t.network.RemoveTransport(t.nodeID)
  165. return nil
  166. }
  167. // MemoryConnection is an in-memory connection between two transport endpoints.
  168. type MemoryConnection struct {
  169. logger log.Logger
  170. localID types.NodeID
  171. remoteID types.NodeID
  172. receiveCh <-chan memoryMessage
  173. sendCh chan<- memoryMessage
  174. closer *tmsync.Closer
  175. }
  176. // memoryMessage is passed internally, containing either a message or handshake.
  177. type memoryMessage struct {
  178. channelID ChannelID
  179. message []byte
  180. // For handshakes.
  181. nodeInfo *types.NodeInfo
  182. pubKey crypto.PubKey
  183. }
  184. // newMemoryConnection creates a new MemoryConnection.
  185. func newMemoryConnection(
  186. logger log.Logger,
  187. localID types.NodeID,
  188. remoteID types.NodeID,
  189. receiveCh <-chan memoryMessage,
  190. sendCh chan<- memoryMessage,
  191. closer *tmsync.Closer,
  192. ) *MemoryConnection {
  193. return &MemoryConnection{
  194. logger: logger.With("remote", remoteID),
  195. localID: localID,
  196. remoteID: remoteID,
  197. receiveCh: receiveCh,
  198. sendCh: sendCh,
  199. closer: closer,
  200. }
  201. }
  202. // String implements Connection.
  203. func (c *MemoryConnection) String() string {
  204. return c.RemoteEndpoint().String()
  205. }
  206. // LocalEndpoint implements Connection.
  207. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  208. return Endpoint{
  209. Protocol: MemoryProtocol,
  210. Path: string(c.localID),
  211. }
  212. }
  213. // RemoteEndpoint implements Connection.
  214. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  215. return Endpoint{
  216. Protocol: MemoryProtocol,
  217. Path: string(c.remoteID),
  218. }
  219. }
  220. // Handshake implements Connection.
  221. func (c *MemoryConnection) Handshake(
  222. ctx context.Context,
  223. nodeInfo types.NodeInfo,
  224. privKey crypto.PrivKey,
  225. ) (types.NodeInfo, crypto.PubKey, error) {
  226. select {
  227. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  228. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  229. case <-c.closer.Done():
  230. return types.NodeInfo{}, nil, io.EOF
  231. case <-ctx.Done():
  232. return types.NodeInfo{}, nil, ctx.Err()
  233. }
  234. select {
  235. case msg := <-c.receiveCh:
  236. if msg.nodeInfo == nil {
  237. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  238. }
  239. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  240. return *msg.nodeInfo, msg.pubKey, nil
  241. case <-c.closer.Done():
  242. return types.NodeInfo{}, nil, io.EOF
  243. case <-ctx.Done():
  244. return types.NodeInfo{}, nil, ctx.Err()
  245. }
  246. }
  247. // ReceiveMessage implements Connection.
  248. func (c *MemoryConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
  249. // Check close first, since channels are buffered. Otherwise, below select
  250. // may non-deterministically return non-error even when closed.
  251. select {
  252. case <-c.closer.Done():
  253. return 0, nil, io.EOF
  254. case <-ctx.Done():
  255. return 0, nil, io.EOF
  256. default:
  257. }
  258. select {
  259. case msg := <-c.receiveCh:
  260. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  261. return msg.channelID, msg.message, nil
  262. case <-c.closer.Done():
  263. return 0, nil, io.EOF
  264. }
  265. }
  266. // SendMessage implements Connection.
  267. func (c *MemoryConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
  268. // Check close first, since channels are buffered. Otherwise, below select
  269. // may non-deterministically return non-error even when closed.
  270. select {
  271. case <-c.closer.Done():
  272. return io.EOF
  273. case <-ctx.Done():
  274. return io.EOF
  275. default:
  276. }
  277. select {
  278. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  279. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  280. return nil
  281. case <-ctx.Done():
  282. return io.EOF
  283. case <-c.closer.Done():
  284. return io.EOF
  285. }
  286. }
  287. // Close implements Connection.
  288. func (c *MemoryConnection) Close() error {
  289. select {
  290. case <-c.closer.Done():
  291. return nil
  292. default:
  293. c.closer.Close()
  294. c.logger.Info("closed connection")
  295. }
  296. return nil
  297. }