You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

372 lines
9.6 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. "github.com/tendermint/tendermint/libs/log"
  11. tmsync "github.com/tendermint/tendermint/libs/sync"
  12. "github.com/tendermint/tendermint/p2p/conn"
  13. )
  14. const (
  15. MemoryProtocol Protocol = "memory"
  16. )
  17. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  18. // communicate between endpoints. It is primarily meant for testing.
  19. //
  20. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  21. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  22. type MemoryNetwork struct {
  23. logger log.Logger
  24. mtx sync.RWMutex
  25. transports map[NodeID]*MemoryTransport
  26. bufferSize int
  27. }
  28. // NewMemoryNetwork creates a new in-memory network.
  29. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  30. return &MemoryNetwork{
  31. bufferSize: bufferSize,
  32. logger: logger,
  33. transports: map[NodeID]*MemoryTransport{},
  34. }
  35. }
  36. // CreateTransport creates a new memory transport endpoint with the given node
  37. // ID and immediately begins listening on the address "memory:<id>". It panics
  38. // if the node ID is already in use (which is fine, since this is for tests).
  39. func (n *MemoryNetwork) CreateTransport(nodeID NodeID) *MemoryTransport {
  40. t := newMemoryTransport(n, nodeID)
  41. n.mtx.Lock()
  42. defer n.mtx.Unlock()
  43. if _, ok := n.transports[nodeID]; ok {
  44. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  45. }
  46. n.transports[nodeID] = t
  47. return t
  48. }
  49. // GetTransport looks up a transport in the network, returning nil if not found.
  50. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
  51. n.mtx.RLock()
  52. defer n.mtx.RUnlock()
  53. return n.transports[id]
  54. }
  55. // RemoveTransport removes a transport from the network and closes it.
  56. func (n *MemoryNetwork) RemoveTransport(id NodeID) {
  57. n.mtx.Lock()
  58. t, ok := n.transports[id]
  59. delete(n.transports, id)
  60. n.mtx.Unlock()
  61. if ok {
  62. // Close may recursively call RemoveTransport() again, but this is safe
  63. // because we've already removed the transport from the map above.
  64. if err := t.Close(); err != nil {
  65. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  66. }
  67. }
  68. }
  69. // Size returns the number of transports in the network.
  70. func (n *MemoryNetwork) Size() int {
  71. return len(n.transports)
  72. }
  73. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  74. // communicate between endpoints. It is primarily meant for testing.
  75. //
  76. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  77. // a different endpoint, both transports must be in the same MemoryNetwork.
  78. type MemoryTransport struct {
  79. logger log.Logger
  80. network *MemoryNetwork
  81. nodeID NodeID
  82. bufferSize int
  83. acceptCh chan *MemoryConnection
  84. closeCh chan struct{}
  85. closeOnce sync.Once
  86. }
  87. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  88. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  89. func newMemoryTransport(network *MemoryNetwork, nodeID NodeID) *MemoryTransport {
  90. return &MemoryTransport{
  91. logger: network.logger.With("local", nodeID),
  92. network: network,
  93. nodeID: nodeID,
  94. bufferSize: network.bufferSize,
  95. acceptCh: make(chan *MemoryConnection),
  96. closeCh: make(chan struct{}),
  97. }
  98. }
  99. // String implements Transport.
  100. func (t *MemoryTransport) String() string {
  101. return string(MemoryProtocol)
  102. }
  103. // Protocols implements Transport.
  104. func (t *MemoryTransport) Protocols() []Protocol {
  105. return []Protocol{MemoryProtocol}
  106. }
  107. // Endpoints implements Transport.
  108. func (t *MemoryTransport) Endpoints() []Endpoint {
  109. select {
  110. case <-t.closeCh:
  111. return []Endpoint{}
  112. default:
  113. return []Endpoint{{
  114. Protocol: MemoryProtocol,
  115. Path: string(t.nodeID),
  116. // An arbitrary IP and port is used in order for the pex
  117. // reactor to be able to send addresses to one another.
  118. IP: net.IPv4zero,
  119. Port: 0,
  120. }}
  121. }
  122. }
  123. // Accept implements Transport.
  124. func (t *MemoryTransport) Accept() (Connection, error) {
  125. select {
  126. case conn := <-t.acceptCh:
  127. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  128. return conn, nil
  129. case <-t.closeCh:
  130. return nil, io.EOF
  131. }
  132. }
  133. // Dial implements Transport.
  134. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  135. if endpoint.Protocol != MemoryProtocol {
  136. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  137. }
  138. if endpoint.Path == "" {
  139. return nil, errors.New("no path")
  140. }
  141. if err := endpoint.Validate(); err != nil {
  142. return nil, err
  143. }
  144. nodeID, err := NewNodeID(endpoint.Path)
  145. if err != nil {
  146. return nil, err
  147. }
  148. t.logger.Info("dialing peer", "remote", nodeID)
  149. peer := t.network.GetTransport(nodeID)
  150. if peer == nil {
  151. return nil, fmt.Errorf("unknown peer %q", nodeID)
  152. }
  153. inCh := make(chan memoryMessage, t.bufferSize)
  154. outCh := make(chan memoryMessage, t.bufferSize)
  155. closer := tmsync.NewCloser()
  156. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  157. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  158. select {
  159. case peer.acceptCh <- inConn:
  160. return outConn, nil
  161. case <-peer.closeCh:
  162. return nil, io.EOF
  163. case <-ctx.Done():
  164. return nil, ctx.Err()
  165. }
  166. }
  167. // Close implements Transport.
  168. func (t *MemoryTransport) Close() error {
  169. t.network.RemoveTransport(t.nodeID)
  170. t.closeOnce.Do(func() {
  171. close(t.closeCh)
  172. t.logger.Info("closed transport")
  173. })
  174. return nil
  175. }
  176. // MemoryConnection is an in-memory connection between two transport endpoints.
  177. type MemoryConnection struct {
  178. logger log.Logger
  179. localID NodeID
  180. remoteID NodeID
  181. receiveCh <-chan memoryMessage
  182. sendCh chan<- memoryMessage
  183. closer *tmsync.Closer
  184. }
  185. // memoryMessage is passed internally, containing either a message or handshake.
  186. type memoryMessage struct {
  187. channelID ChannelID
  188. message []byte
  189. // For handshakes.
  190. nodeInfo *NodeInfo
  191. pubKey crypto.PubKey
  192. }
  193. // newMemoryConnection creates a new MemoryConnection.
  194. func newMemoryConnection(
  195. logger log.Logger,
  196. localID NodeID,
  197. remoteID NodeID,
  198. receiveCh <-chan memoryMessage,
  199. sendCh chan<- memoryMessage,
  200. closer *tmsync.Closer,
  201. ) *MemoryConnection {
  202. return &MemoryConnection{
  203. logger: logger.With("remote", remoteID),
  204. localID: localID,
  205. remoteID: remoteID,
  206. receiveCh: receiveCh,
  207. sendCh: sendCh,
  208. closer: closer,
  209. }
  210. }
  211. // String implements Connection.
  212. func (c *MemoryConnection) String() string {
  213. return c.RemoteEndpoint().String()
  214. }
  215. // LocalEndpoint implements Connection.
  216. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  217. return Endpoint{
  218. Protocol: MemoryProtocol,
  219. Path: string(c.localID),
  220. }
  221. }
  222. // RemoteEndpoint implements Connection.
  223. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  224. return Endpoint{
  225. Protocol: MemoryProtocol,
  226. Path: string(c.remoteID),
  227. }
  228. }
  229. // Status implements Connection.
  230. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  231. return conn.ConnectionStatus{}
  232. }
  233. // Handshake implements Connection.
  234. func (c *MemoryConnection) Handshake(
  235. ctx context.Context,
  236. nodeInfo NodeInfo,
  237. privKey crypto.PrivKey,
  238. ) (NodeInfo, crypto.PubKey, error) {
  239. select {
  240. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  241. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  242. case <-c.closer.Done():
  243. return NodeInfo{}, nil, io.EOF
  244. case <-ctx.Done():
  245. return NodeInfo{}, nil, ctx.Err()
  246. }
  247. select {
  248. case msg := <-c.receiveCh:
  249. if msg.nodeInfo == nil {
  250. return NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  251. }
  252. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  253. return *msg.nodeInfo, msg.pubKey, nil
  254. case <-c.closer.Done():
  255. return NodeInfo{}, nil, io.EOF
  256. case <-ctx.Done():
  257. return NodeInfo{}, nil, ctx.Err()
  258. }
  259. }
  260. // ReceiveMessage implements Connection.
  261. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
  262. // Check close first, since channels are buffered. Otherwise, below select
  263. // may non-deterministically return non-error even when closed.
  264. select {
  265. case <-c.closer.Done():
  266. return 0, nil, io.EOF
  267. default:
  268. }
  269. select {
  270. case msg := <-c.receiveCh:
  271. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  272. return msg.channelID, msg.message, nil
  273. case <-c.closer.Done():
  274. return 0, nil, io.EOF
  275. }
  276. }
  277. // SendMessage implements Connection.
  278. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) {
  279. // Check close first, since channels are buffered. Otherwise, below select
  280. // may non-deterministically return non-error even when closed.
  281. select {
  282. case <-c.closer.Done():
  283. return false, io.EOF
  284. default:
  285. }
  286. select {
  287. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  288. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  289. return true, nil
  290. case <-c.closer.Done():
  291. return false, io.EOF
  292. }
  293. }
  294. // TrySendMessage implements Connection.
  295. func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) {
  296. // Check close first, since channels are buffered. Otherwise, below select
  297. // may non-deterministically return non-error even when closed.
  298. select {
  299. case <-c.closer.Done():
  300. return false, io.EOF
  301. default:
  302. }
  303. select {
  304. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  305. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  306. return true, nil
  307. case <-c.closer.Done():
  308. return false, io.EOF
  309. default:
  310. return false, nil
  311. }
  312. }
  313. // Close implements Connection.
  314. func (c *MemoryConnection) Close() error {
  315. select {
  316. case <-c.closer.Done():
  317. return nil
  318. default:
  319. c.closer.Close()
  320. c.logger.Info("closed connection")
  321. }
  322. return nil
  323. }
  324. // FlushClose implements Connection.
  325. func (c *MemoryConnection) FlushClose() error {
  326. return c.Close()
  327. }