You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

373 lines
9.7 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. tmsync "github.com/tendermint/tendermint/internal/libs/sync"
  11. "github.com/tendermint/tendermint/internal/p2p/conn"
  12. "github.com/tendermint/tendermint/libs/log"
  13. "github.com/tendermint/tendermint/types"
  14. )
const (
	// MemoryProtocol is the protocol name reported by MemoryTransport and
	// required of every endpoint passed to MemoryTransport.Dial.
	MemoryProtocol Protocol = "memory"
)
// MemoryNetwork is an in-memory "network" that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// Network endpoints are allocated via CreateTransport(), which takes a node ID,
// and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
type MemoryNetwork struct {
	// logger is the base logger; each transport derives its own logger from
	// it, and RemoveTransport uses it to report close failures.
	logger log.Logger

	// mtx is held by CreateTransport, GetTransport and RemoveTransport while
	// they access transports.
	mtx sync.RWMutex
	// transports holds the registered endpoints, keyed by node ID.
	transports map[types.NodeID]*MemoryTransport
	// bufferSize is copied into every transport created on this network,
	// where it sizes the per-connection message channels.
	bufferSize int
}
  29. // NewMemoryNetwork creates a new in-memory network.
  30. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  31. return &MemoryNetwork{
  32. bufferSize: bufferSize,
  33. logger: logger,
  34. transports: map[types.NodeID]*MemoryTransport{},
  35. }
  36. }
  37. // CreateTransport creates a new memory transport endpoint with the given node
  38. // ID and immediately begins listening on the address "memory:<id>". It panics
  39. // if the node ID is already in use (which is fine, since this is for tests).
  40. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  41. t := newMemoryTransport(n, nodeID)
  42. n.mtx.Lock()
  43. defer n.mtx.Unlock()
  44. if _, ok := n.transports[nodeID]; ok {
  45. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  46. }
  47. n.transports[nodeID] = t
  48. return t
  49. }
  50. // GetTransport looks up a transport in the network, returning nil if not found.
  51. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  52. n.mtx.RLock()
  53. defer n.mtx.RUnlock()
  54. return n.transports[id]
  55. }
  56. // RemoveTransport removes a transport from the network and closes it.
  57. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  58. n.mtx.Lock()
  59. t, ok := n.transports[id]
  60. delete(n.transports, id)
  61. n.mtx.Unlock()
  62. if ok {
  63. // Close may recursively call RemoveTransport() again, but this is safe
  64. // because we've already removed the transport from the map above.
  65. if err := t.Close(); err != nil {
  66. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  67. }
  68. }
  69. }
  70. // Size returns the number of transports in the network.
  71. func (n *MemoryNetwork) Size() int {
  72. return len(n.transports)
  73. }
// MemoryTransport is an in-memory transport that uses buffered Go channels to
// communicate between endpoints. It is primarily meant for testing.
//
// New transports are allocated with MemoryNetwork.CreateTransport(). To contact
// a different endpoint, both transports must be in the same MemoryNetwork.
type MemoryTransport struct {
	// logger is the network logger annotated with this transport's node ID
	// under the "local" key.
	logger log.Logger
	// network is the owning network; Dial looks peers up in it and Close
	// removes this transport from it.
	network *MemoryNetwork
	// nodeID identifies this transport within the network.
	nodeID types.NodeID
	// bufferSize is the capacity of the two message channels that Dial
	// creates for each new connection.
	bufferSize int

	// acceptCh is unbuffered; a dialing peer sends the inbound side of a new
	// connection on it and Accept receives it.
	acceptCh chan *MemoryConnection
	// closeCh is closed by Close to signal that the transport has shut down.
	closeCh chan struct{}
	// closeOnce makes sure closeCh is closed only once.
	closeOnce sync.Once
}
  88. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  89. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  90. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  91. return &MemoryTransport{
  92. logger: network.logger.With("local", nodeID),
  93. network: network,
  94. nodeID: nodeID,
  95. bufferSize: network.bufferSize,
  96. acceptCh: make(chan *MemoryConnection),
  97. closeCh: make(chan struct{}),
  98. }
  99. }
  100. // String implements Transport.
  101. func (t *MemoryTransport) String() string {
  102. return string(MemoryProtocol)
  103. }
  104. // Protocols implements Transport.
  105. func (t *MemoryTransport) Protocols() []Protocol {
  106. return []Protocol{MemoryProtocol}
  107. }
  108. // Endpoints implements Transport.
  109. func (t *MemoryTransport) Endpoints() []Endpoint {
  110. select {
  111. case <-t.closeCh:
  112. return []Endpoint{}
  113. default:
  114. return []Endpoint{{
  115. Protocol: MemoryProtocol,
  116. Path: string(t.nodeID),
  117. // An arbitrary IP and port is used in order for the pex
  118. // reactor to be able to send addresses to one another.
  119. IP: net.IPv4zero,
  120. Port: 0,
  121. }}
  122. }
  123. }
  124. // Accept implements Transport.
  125. func (t *MemoryTransport) Accept() (Connection, error) {
  126. select {
  127. case conn := <-t.acceptCh:
  128. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  129. return conn, nil
  130. case <-t.closeCh:
  131. return nil, io.EOF
  132. }
  133. }
  134. // Dial implements Transport.
  135. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  136. if endpoint.Protocol != MemoryProtocol {
  137. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  138. }
  139. if endpoint.Path == "" {
  140. return nil, errors.New("no path")
  141. }
  142. if err := endpoint.Validate(); err != nil {
  143. return nil, err
  144. }
  145. nodeID, err := types.NewNodeID(endpoint.Path)
  146. if err != nil {
  147. return nil, err
  148. }
  149. t.logger.Info("dialing peer", "remote", nodeID)
  150. peer := t.network.GetTransport(nodeID)
  151. if peer == nil {
  152. return nil, fmt.Errorf("unknown peer %q", nodeID)
  153. }
  154. inCh := make(chan memoryMessage, t.bufferSize)
  155. outCh := make(chan memoryMessage, t.bufferSize)
  156. closer := tmsync.NewCloser()
  157. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh, closer)
  158. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh, closer)
  159. select {
  160. case peer.acceptCh <- inConn:
  161. return outConn, nil
  162. case <-peer.closeCh:
  163. return nil, io.EOF
  164. case <-ctx.Done():
  165. return nil, ctx.Err()
  166. }
  167. }
  168. // Close implements Transport.
  169. func (t *MemoryTransport) Close() error {
  170. t.network.RemoveTransport(t.nodeID)
  171. t.closeOnce.Do(func() {
  172. close(t.closeCh)
  173. t.logger.Info("closed transport")
  174. })
  175. return nil
  176. }
// MemoryConnection is an in-memory connection between two transport endpoints.
type MemoryConnection struct {
	// logger is the owning transport's logger annotated with the remote node
	// ID under the "remote" key.
	logger log.Logger
	// localID is the node ID of this side of the connection.
	localID types.NodeID
	// remoteID is the node ID of the other side of the connection.
	remoteID types.NodeID

	// receiveCh delivers messages sent by the remote side.
	receiveCh <-chan memoryMessage
	// sendCh carries messages to the remote side.
	sendCh chan<- memoryMessage
	// closer signals that the connection is closed. Dial hands the same
	// closer to both sides, so closing one side closes the other as well.
	closer *tmsync.Closer
}
// memoryMessage is passed internally, containing either a message or handshake.
type memoryMessage struct {
	// channelID and message are set for regular messages sent via
	// SendMessage or TrySendMessage.
	channelID ChannelID
	message   []byte

	// For handshakes. Handshake treats a received message whose nodeInfo is
	// nil as an error.
	nodeInfo *types.NodeInfo
	pubKey   crypto.PubKey
}
  194. // newMemoryConnection creates a new MemoryConnection.
  195. func newMemoryConnection(
  196. logger log.Logger,
  197. localID types.NodeID,
  198. remoteID types.NodeID,
  199. receiveCh <-chan memoryMessage,
  200. sendCh chan<- memoryMessage,
  201. closer *tmsync.Closer,
  202. ) *MemoryConnection {
  203. return &MemoryConnection{
  204. logger: logger.With("remote", remoteID),
  205. localID: localID,
  206. remoteID: remoteID,
  207. receiveCh: receiveCh,
  208. sendCh: sendCh,
  209. closer: closer,
  210. }
  211. }
  212. // String implements Connection.
  213. func (c *MemoryConnection) String() string {
  214. return c.RemoteEndpoint().String()
  215. }
  216. // LocalEndpoint implements Connection.
  217. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  218. return Endpoint{
  219. Protocol: MemoryProtocol,
  220. Path: string(c.localID),
  221. }
  222. }
  223. // RemoteEndpoint implements Connection.
  224. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  225. return Endpoint{
  226. Protocol: MemoryProtocol,
  227. Path: string(c.remoteID),
  228. }
  229. }
  230. // Status implements Connection.
  231. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  232. return conn.ConnectionStatus{}
  233. }
  234. // Handshake implements Connection.
  235. func (c *MemoryConnection) Handshake(
  236. ctx context.Context,
  237. nodeInfo types.NodeInfo,
  238. privKey crypto.PrivKey,
  239. ) (types.NodeInfo, crypto.PubKey, error) {
  240. select {
  241. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  242. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  243. case <-c.closer.Done():
  244. return types.NodeInfo{}, nil, io.EOF
  245. case <-ctx.Done():
  246. return types.NodeInfo{}, nil, ctx.Err()
  247. }
  248. select {
  249. case msg := <-c.receiveCh:
  250. if msg.nodeInfo == nil {
  251. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  252. }
  253. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  254. return *msg.nodeInfo, msg.pubKey, nil
  255. case <-c.closer.Done():
  256. return types.NodeInfo{}, nil, io.EOF
  257. case <-ctx.Done():
  258. return types.NodeInfo{}, nil, ctx.Err()
  259. }
  260. }
  261. // ReceiveMessage implements Connection.
  262. func (c *MemoryConnection) ReceiveMessage() (ChannelID, []byte, error) {
  263. // Check close first, since channels are buffered. Otherwise, below select
  264. // may non-deterministically return non-error even when closed.
  265. select {
  266. case <-c.closer.Done():
  267. return 0, nil, io.EOF
  268. default:
  269. }
  270. select {
  271. case msg := <-c.receiveCh:
  272. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  273. return msg.channelID, msg.message, nil
  274. case <-c.closer.Done():
  275. return 0, nil, io.EOF
  276. }
  277. }
  278. // SendMessage implements Connection.
  279. func (c *MemoryConnection) SendMessage(chID ChannelID, msg []byte) (bool, error) {
  280. // Check close first, since channels are buffered. Otherwise, below select
  281. // may non-deterministically return non-error even when closed.
  282. select {
  283. case <-c.closer.Done():
  284. return false, io.EOF
  285. default:
  286. }
  287. select {
  288. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  289. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  290. return true, nil
  291. case <-c.closer.Done():
  292. return false, io.EOF
  293. }
  294. }
  295. // TrySendMessage implements Connection.
  296. func (c *MemoryConnection) TrySendMessage(chID ChannelID, msg []byte) (bool, error) {
  297. // Check close first, since channels are buffered. Otherwise, below select
  298. // may non-deterministically return non-error even when closed.
  299. select {
  300. case <-c.closer.Done():
  301. return false, io.EOF
  302. default:
  303. }
  304. select {
  305. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  306. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  307. return true, nil
  308. case <-c.closer.Done():
  309. return false, io.EOF
  310. default:
  311. return false, nil
  312. }
  313. }
  314. // Close implements Connection.
  315. func (c *MemoryConnection) Close() error {
  316. select {
  317. case <-c.closer.Done():
  318. return nil
  319. default:
  320. c.closer.Close()
  321. c.logger.Info("closed connection")
  322. }
  323. return nil
  324. }
  325. // FlushClose implements Connection.
  326. func (c *MemoryConnection) FlushClose() error {
  327. return c.Close()
  328. }