You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

363 lines
9.3 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto"
  9. "github.com/tendermint/tendermint/libs/log"
  10. tmsync "github.com/tendermint/tendermint/libs/sync"
  11. "github.com/tendermint/tendermint/p2p/conn"
  12. )
const (
	// MemoryProtocol is the protocol name reported by MemoryTransport and
	// MemoryConnection endpoints, and the only protocol Dial accepts.
	MemoryProtocol Protocol = "memory"
)
// MemoryNetwork is an in-memory "network" that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// Network endpoints are allocated via CreateTransport(), which takes a node ID,
// and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
type MemoryNetwork struct {
	// logger is the base logger; each transport derives its own from it.
	logger log.Logger
	// mtx is held by CreateTransport, GetTransport and RemoveTransport while
	// they access transports.
	mtx sync.RWMutex
	// transports holds the registered endpoints, keyed by node ID.
	transports map[NodeID]*MemoryTransport
	// bufferSize is copied into every transport and used as the capacity of
	// the message channels created when dialing.
	bufferSize int
}
  27. // NewMemoryNetwork creates a new in-memory network.
  28. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  29. return &MemoryNetwork{
  30. bufferSize: bufferSize,
  31. logger: logger,
  32. transports: map[NodeID]*MemoryTransport{},
  33. }
  34. }
  35. // CreateTransport creates a new memory transport endpoint with the given node
  36. // ID and immediately begins listening on the address "memory:<id>". It panics
  37. // if the node ID is already in use (which is fine, since this is for tests).
  38. func (n *MemoryNetwork) CreateTransport(nodeID NodeID) *MemoryTransport {
  39. t := newMemoryTransport(n, nodeID)
  40. n.mtx.Lock()
  41. defer n.mtx.Unlock()
  42. if _, ok := n.transports[nodeID]; ok {
  43. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  44. }
  45. n.transports[nodeID] = t
  46. return t
  47. }
  48. // GetTransport looks up a transport in the network, returning nil if not found.
  49. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
  50. n.mtx.RLock()
  51. defer n.mtx.RUnlock()
  52. return n.transports[id]
  53. }
  54. // RemoveTransport removes a transport from the network and closes it.
  55. func (n *MemoryNetwork) RemoveTransport(id NodeID) {
  56. n.mtx.Lock()
  57. t, ok := n.transports[id]
  58. delete(n.transports, id)
  59. n.mtx.Unlock()
  60. if ok {
  61. // Close may recursively call RemoveTransport() again, but this is safe
  62. // because we've already removed the transport from the map above.
  63. if err := t.Close(); err != nil {
  64. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  65. }
  66. }
  67. }
  68. // Size returns the number of transports in the network.
  69. func (n *MemoryNetwork) Size() int {
  70. return len(n.transports)
  71. }
// MemoryTransport is an in-memory transport that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// New transports are allocated with MemoryNetwork.CreateTransport(). To contact
// a different endpoint, both transports must be in the same MemoryNetwork.
type MemoryTransport struct {
	// logger is the network logger annotated with this transport's node ID.
	logger log.Logger
	// network is the MemoryNetwork used to look up peers when dialing and to
	// unregister this transport on Close.
	network *MemoryNetwork
	// nodeID identifies this endpoint; it is also the endpoint path.
	nodeID NodeID
	// bufferSize is the capacity of the per-connection message channels
	// created by Dial.
	bufferSize int
	// acceptCh is an unbuffered channel on which Dial hands the inbound side
	// of a new connection to Accept.
	acceptCh chan *MemoryConnection
	// closeCh is closed by Close; Accept, Dial and Endpoints select on it.
	closeCh chan struct{}
	// closeOnce ensures closeCh is closed only once.
	closeOnce sync.Once
}
  86. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  87. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  88. func newMemoryTransport(network *MemoryNetwork, nodeID NodeID) *MemoryTransport {
  89. return &MemoryTransport{
  90. logger: network.logger.With("local", nodeID),
  91. network: network,
  92. nodeID: nodeID,
  93. bufferSize: network.bufferSize,
  94. acceptCh: make(chan *MemoryConnection),
  95. closeCh: make(chan struct{}),
  96. }
  97. }
  98. // String implements Transport.
  99. func (t *MemoryTransport) String() string {
  100. return string(MemoryProtocol)
  101. }
  102. // Protocols implements Transport.
  103. func (t *MemoryTransport) Protocols() []Protocol {
  104. return []Protocol{MemoryProtocol}
  105. }
  106. // Endpoints implements Transport.
  107. func (t *MemoryTransport) Endpoints() []Endpoint {
  108. select {
  109. case <-t.closeCh:
  110. return []Endpoint{}
  111. default:
  112. return []Endpoint{{
  113. Protocol: MemoryProtocol,
  114. Path: string(t.nodeID),
  115. }}
  116. }
  117. }
  118. // Accept implements Transport.
  119. func (t *MemoryTransport) Accept() (Connection, error) {
  120. select {
  121. case conn := <-t.acceptCh:
  122. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  123. return conn, nil
  124. case <-t.closeCh:
  125. return nil, io.EOF
  126. }
  127. }
  128. // Dial implements Transport.
  129. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  130. if endpoint.Protocol != MemoryProtocol {
  131. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  132. }
  133. if endpoint.Path == "" {
  134. return nil, errors.New("no path")
  135. }
  136. nodeID, err := NewNodeID(endpoint.Path)
  137. if err != nil {
  138. return nil, err
  139. }
  140. t.logger.Info("dialing peer", "remote", nodeID)
  141. peer := t.network.GetTransport(nodeID)
  142. if peer == nil {
  143. return nil, fmt.Errorf("unknown peer %q", nodeID)
  144. }
  145. inCh := make(chan memoryMessage, t.bufferSize)
  146. outCh := make(chan memoryMessage, t.bufferSize)
  147. closer := tmsync.NewCloser()
  148. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  149. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  150. select {
  151. case peer.acceptCh <- inConn:
  152. return outConn, nil
  153. case <-peer.closeCh:
  154. return nil, io.EOF
  155. case <-ctx.Done():
  156. return nil, ctx.Err()
  157. }
  158. }
  159. // Close implements Transport.
  160. func (t *MemoryTransport) Close() error {
  161. t.network.RemoveTransport(t.nodeID)
  162. t.closeOnce.Do(func() {
  163. close(t.closeCh)
  164. t.logger.Info("closed transport")
  165. })
  166. return nil
  167. }
// MemoryConnection is an in-memory connection between two transport endpoints.
type MemoryConnection struct {
	// logger is the owning transport's logger annotated with the remote ID.
	logger log.Logger
	// localID and remoteID are the node IDs of the two ends, used to build
	// the local and remote endpoints.
	localID  NodeID
	remoteID NodeID
	// receiveCh carries messages from the peer; sendCh carries messages to it.
	receiveCh <-chan memoryMessage
	sendCh    chan<- memoryMessage
	// closer is shared by both ends of the connection (see Dial), so closing
	// one end is observed by the other.
	closer *tmsync.Closer
}
// memoryMessage is passed internally, containing either a message or handshake.
type memoryMessage struct {
	// channelID and message are set for regular messages sent via
	// SendMessage/TrySendMessage.
	channelID ChannelID
	message   []byte
	// For handshakes. Handshake rejects a received message whose nodeInfo
	// is nil.
	nodeInfo *NodeInfo
	pubKey   crypto.PubKey
}
  185. // newMemoryConnection creates a new MemoryConnection.
  186. func newMemoryConnection(
  187. logger log.Logger,
  188. localID NodeID,
  189. remoteID NodeID,
  190. receiveCh <-chan memoryMessage,
  191. sendCh chan<- memoryMessage,
  192. closer *tmsync.Closer,
  193. ) *MemoryConnection {
  194. return &MemoryConnection{
  195. logger: logger.With("remote", remoteID),
  196. localID: localID,
  197. remoteID: remoteID,
  198. receiveCh: receiveCh,
  199. sendCh: sendCh,
  200. closer: closer,
  201. }
  202. }
  203. // String implements Connection.
  204. func (c *MemoryConnection) String() string {
  205. return c.RemoteEndpoint().String()
  206. }
  207. // LocalEndpoint implements Connection.
  208. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  209. return Endpoint{
  210. Protocol: MemoryProtocol,
  211. Path: string(c.localID),
  212. }
  213. }
  214. // RemoteEndpoint implements Connection.
  215. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  216. return Endpoint{
  217. Protocol: MemoryProtocol,
  218. Path: string(c.remoteID),
  219. }
  220. }
  221. // Status implements Connection.
  222. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  223. return conn.ConnectionStatus{}
  224. }
  225. // Handshake implements Connection.
  226. func (c *MemoryConnection) Handshake(
  227. ctx context.Context,
  228. nodeInfo NodeInfo,
  229. privKey crypto.PrivKey,
  230. ) (NodeInfo, crypto.PubKey, error) {
  231. select {
  232. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  233. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  234. case <-c.closer.Done():
  235. return NodeInfo{}, nil, io.EOF
  236. case <-ctx.Done():
  237. return NodeInfo{}, nil, ctx.Err()
  238. }
  239. select {
  240. case msg := <-c.receiveCh:
  241. if msg.nodeInfo == nil {
  242. return NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  243. }
  244. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  245. return *msg.nodeInfo, msg.pubKey, nil
  246. case <-c.closer.Done():
  247. return NodeInfo{}, nil, io.EOF
  248. case <-ctx.Done():
  249. return NodeInfo{}, nil, ctx.Err()
  250. }
  251. }
  252. // ReceiveMessage implements Connection.
  253. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
  254. // Check close first, since channels are buffered. Otherwise, below select
  255. // may non-deterministically return non-error even when closed.
  256. select {
  257. case <-c.closer.Done():
  258. return 0, nil, io.EOF
  259. default:
  260. }
  261. select {
  262. case msg := <-c.receiveCh:
  263. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  264. return msg.channelID, msg.message, nil
  265. case <-c.closer.Done():
  266. return 0, nil, io.EOF
  267. }
  268. }
  269. // SendMessage implements Connection.
  270. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) {
  271. // Check close first, since channels are buffered. Otherwise, below select
  272. // may non-deterministically return non-error even when closed.
  273. select {
  274. case <-c.closer.Done():
  275. return false, io.EOF
  276. default:
  277. }
  278. select {
  279. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  280. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  281. return true, nil
  282. case <-c.closer.Done():
  283. return false, io.EOF
  284. }
  285. }
  286. // TrySendMessage implements Connection.
  287. func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) {
  288. // Check close first, since channels are buffered. Otherwise, below select
  289. // may non-deterministically return non-error even when closed.
  290. select {
  291. case <-c.closer.Done():
  292. return false, io.EOF
  293. default:
  294. }
  295. select {
  296. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  297. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  298. return true, nil
  299. case <-c.closer.Done():
  300. return false, io.EOF
  301. default:
  302. return false, nil
  303. }
  304. }
  305. // Close implements Connection.
  306. func (c *MemoryConnection) Close() error {
  307. select {
  308. case <-c.closer.Done():
  309. return nil
  310. default:
  311. c.closer.Close()
  312. c.logger.Info("closed connection")
  313. }
  314. return nil
  315. }
  316. // FlushClose implements Connection.
  317. func (c *MemoryConnection) FlushClose() error {
  318. return c.Close()
  319. }