You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

340 lines
9.0 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. "github.com/tendermint/tendermint/libs/log"
  11. "github.com/tendermint/tendermint/types"
  12. )
  13. const (
  14. MemoryProtocol Protocol = "memory"
  15. )
  16. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  17. // communicate between endpoints. It is primarily meant for testing.
  18. //
  19. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  20. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  21. type MemoryNetwork struct {
  22. logger log.Logger
  23. mtx sync.RWMutex
  24. transports map[types.NodeID]*MemoryTransport
  25. bufferSize int
  26. }
  27. // NewMemoryNetwork creates a new in-memory network.
  28. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  29. return &MemoryNetwork{
  30. bufferSize: bufferSize,
  31. logger: logger,
  32. transports: map[types.NodeID]*MemoryTransport{},
  33. }
  34. }
  35. // CreateTransport creates a new memory transport endpoint with the given node
  36. // ID and immediately begins listening on the address "memory:<id>". It panics
  37. // if the node ID is already in use (which is fine, since this is for tests).
  38. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  39. t := newMemoryTransport(n, nodeID)
  40. n.mtx.Lock()
  41. defer n.mtx.Unlock()
  42. if _, ok := n.transports[nodeID]; ok {
  43. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  44. }
  45. n.transports[nodeID] = t
  46. return t
  47. }
  48. // GetTransport looks up a transport in the network, returning nil if not found.
  49. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  50. n.mtx.RLock()
  51. defer n.mtx.RUnlock()
  52. return n.transports[id]
  53. }
  54. // RemoveTransport removes a transport from the network and closes it.
  55. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  56. n.mtx.Lock()
  57. t, ok := n.transports[id]
  58. delete(n.transports, id)
  59. n.mtx.Unlock()
  60. if ok {
  61. // Close may recursively call RemoveTransport() again, but this is safe
  62. // because we've already removed the transport from the map above.
  63. if err := t.Close(); err != nil {
  64. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  65. }
  66. }
  67. }
  68. // Size returns the number of transports in the network.
  69. func (n *MemoryNetwork) Size() int {
  70. return len(n.transports)
  71. }
  72. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  73. // communicate between endpoints. It is primarily meant for testing.
  74. //
  75. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  76. // a different endpoint, both transports must be in the same MemoryNetwork.
  77. type MemoryTransport struct {
  78. logger log.Logger
  79. network *MemoryNetwork
  80. nodeID types.NodeID
  81. bufferSize int
  82. acceptCh chan *MemoryConnection
  83. }
  84. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  85. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  86. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  87. return &MemoryTransport{
  88. logger: network.logger.With("local", nodeID),
  89. network: network,
  90. nodeID: nodeID,
  91. bufferSize: network.bufferSize,
  92. acceptCh: make(chan *MemoryConnection),
  93. }
  94. }
  95. // String implements Transport.
  96. func (t *MemoryTransport) String() string {
  97. return string(MemoryProtocol)
  98. }
  99. func (*MemoryTransport) Listen(Endpoint) error { return nil }
  100. func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {}
  101. // Protocols implements Transport.
  102. func (t *MemoryTransport) Protocols() []Protocol {
  103. return []Protocol{MemoryProtocol}
  104. }
  105. // Endpoints implements Transport.
  106. func (t *MemoryTransport) Endpoints() []Endpoint {
  107. if n := t.network.GetTransport(t.nodeID); n == nil {
  108. return []Endpoint{}
  109. }
  110. return []Endpoint{{
  111. Protocol: MemoryProtocol,
  112. Path: string(t.nodeID),
  113. // An arbitrary IP and port is used in order for the pex
  114. // reactor to be able to send addresses to one another.
  115. IP: net.IPv4zero,
  116. Port: 0,
  117. }}
  118. }
  119. // Accept implements Transport.
  120. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  121. select {
  122. case conn := <-t.acceptCh:
  123. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  124. return conn, nil
  125. case <-ctx.Done():
  126. return nil, io.EOF
  127. }
  128. }
  129. // Dial implements Transport.
  130. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  131. if endpoint.Protocol != MemoryProtocol {
  132. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  133. }
  134. if endpoint.Path == "" {
  135. return nil, errors.New("no path")
  136. }
  137. if err := endpoint.Validate(); err != nil {
  138. return nil, err
  139. }
  140. nodeID, err := types.NewNodeID(endpoint.Path)
  141. if err != nil {
  142. return nil, err
  143. }
  144. t.logger.Info("dialing peer", "remote", nodeID)
  145. peer := t.network.GetTransport(nodeID)
  146. if peer == nil {
  147. return nil, fmt.Errorf("unknown peer %q", nodeID)
  148. }
  149. inCh := make(chan memoryMessage, t.bufferSize)
  150. outCh := make(chan memoryMessage, t.bufferSize)
  151. once := &sync.Once{}
  152. closeCh := make(chan struct{})
  153. closeFn := func() { once.Do(func() { close(closeCh) }) }
  154. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh)
  155. outConn.closeCh = closeCh
  156. outConn.closeFn = closeFn
  157. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh)
  158. inConn.closeCh = closeCh
  159. inConn.closeFn = closeFn
  160. select {
  161. case peer.acceptCh <- inConn:
  162. return outConn, nil
  163. case <-ctx.Done():
  164. return nil, io.EOF
  165. }
  166. }
  167. // Close implements Transport.
  168. func (t *MemoryTransport) Close() error {
  169. t.network.RemoveTransport(t.nodeID)
  170. return nil
  171. }
  172. // MemoryConnection is an in-memory connection between two transport endpoints.
  173. type MemoryConnection struct {
  174. logger log.Logger
  175. localID types.NodeID
  176. remoteID types.NodeID
  177. receiveCh <-chan memoryMessage
  178. sendCh chan<- memoryMessage
  179. closeFn func()
  180. closeCh <-chan struct{}
  181. }
  182. // memoryMessage is passed internally, containing either a message or handshake.
  183. type memoryMessage struct {
  184. channelID ChannelID
  185. message []byte
  186. // For handshakes.
  187. nodeInfo *types.NodeInfo
  188. pubKey crypto.PubKey
  189. }
  190. // newMemoryConnection creates a new MemoryConnection.
  191. func newMemoryConnection(
  192. logger log.Logger,
  193. localID types.NodeID,
  194. remoteID types.NodeID,
  195. receiveCh <-chan memoryMessage,
  196. sendCh chan<- memoryMessage,
  197. ) *MemoryConnection {
  198. return &MemoryConnection{
  199. logger: logger.With("remote", remoteID),
  200. localID: localID,
  201. remoteID: remoteID,
  202. receiveCh: receiveCh,
  203. sendCh: sendCh,
  204. }
  205. }
  206. // String implements Connection.
  207. func (c *MemoryConnection) String() string {
  208. return c.RemoteEndpoint().String()
  209. }
  210. // LocalEndpoint implements Connection.
  211. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  212. return Endpoint{
  213. Protocol: MemoryProtocol,
  214. Path: string(c.localID),
  215. }
  216. }
  217. // RemoteEndpoint implements Connection.
  218. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  219. return Endpoint{
  220. Protocol: MemoryProtocol,
  221. Path: string(c.remoteID),
  222. }
  223. }
  224. // Handshake implements Connection.
  225. func (c *MemoryConnection) Handshake(
  226. ctx context.Context,
  227. nodeInfo types.NodeInfo,
  228. privKey crypto.PrivKey,
  229. ) (types.NodeInfo, crypto.PubKey, error) {
  230. select {
  231. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  232. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  233. case <-c.closeCh:
  234. return types.NodeInfo{}, nil, io.EOF
  235. case <-ctx.Done():
  236. return types.NodeInfo{}, nil, ctx.Err()
  237. }
  238. select {
  239. case msg := <-c.receiveCh:
  240. if msg.nodeInfo == nil {
  241. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  242. }
  243. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  244. return *msg.nodeInfo, msg.pubKey, nil
  245. case <-c.closeCh:
  246. return types.NodeInfo{}, nil, io.EOF
  247. case <-ctx.Done():
  248. return types.NodeInfo{}, nil, ctx.Err()
  249. }
  250. }
  251. // ReceiveMessage implements Connection.
  252. func (c *MemoryConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
  253. // Check close first, since channels are buffered. Otherwise, below select
  254. // may non-deterministically return non-error even when closed.
  255. select {
  256. case <-c.closeCh:
  257. return 0, nil, io.EOF
  258. case <-ctx.Done():
  259. return 0, nil, io.EOF
  260. default:
  261. }
  262. select {
  263. case msg := <-c.receiveCh:
  264. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  265. return msg.channelID, msg.message, nil
  266. case <-ctx.Done():
  267. return 0, nil, io.EOF
  268. case <-c.closeCh:
  269. return 0, nil, io.EOF
  270. }
  271. }
  272. // SendMessage implements Connection.
  273. func (c *MemoryConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
  274. // Check close first, since channels are buffered. Otherwise, below select
  275. // may non-deterministically return non-error even when closed.
  276. select {
  277. case <-c.closeCh:
  278. return io.EOF
  279. case <-ctx.Done():
  280. return io.EOF
  281. default:
  282. }
  283. select {
  284. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  285. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  286. return nil
  287. case <-ctx.Done():
  288. return io.EOF
  289. case <-c.closeCh:
  290. return io.EOF
  291. }
  292. }
  293. // Close implements Connection.
  294. func (c *MemoryConnection) Close() error { c.closeFn(); return nil }