You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

349 lines
9.2 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "github.com/tendermint/tendermint/crypto"
  10. "github.com/tendermint/tendermint/libs/log"
  11. "github.com/tendermint/tendermint/types"
  12. )
  13. const (
  14. MemoryProtocol Protocol = "memory"
  15. )
  16. // MemoryNetwork is an in-memory "network" that uses buffered Go channels to
  17. // communicate between endpoints. It is primarily meant for testing.
  18. //
  19. // Network endpoints are allocated via CreateTransport(), which takes a node ID,
  20. // and the endpoint is then immediately accessible via the URL "memory:<nodeID>".
  21. type MemoryNetwork struct {
  22. logger log.Logger
  23. mtx sync.RWMutex
  24. transports map[types.NodeID]*MemoryTransport
  25. bufferSize int
  26. }
  27. // NewMemoryNetwork creates a new in-memory network.
  28. func NewMemoryNetwork(logger log.Logger, bufferSize int) *MemoryNetwork {
  29. return &MemoryNetwork{
  30. bufferSize: bufferSize,
  31. logger: logger,
  32. transports: map[types.NodeID]*MemoryTransport{},
  33. }
  34. }
  35. // CreateTransport creates a new memory transport endpoint with the given node
  36. // ID and immediately begins listening on the address "memory:<id>". It panics
  37. // if the node ID is already in use (which is fine, since this is for tests).
  38. func (n *MemoryNetwork) CreateTransport(nodeID types.NodeID) *MemoryTransport {
  39. t := newMemoryTransport(n, nodeID)
  40. n.mtx.Lock()
  41. defer n.mtx.Unlock()
  42. if _, ok := n.transports[nodeID]; ok {
  43. panic(fmt.Sprintf("memory transport with node ID %q already exists", nodeID))
  44. }
  45. n.transports[nodeID] = t
  46. return t
  47. }
  48. // GetTransport looks up a transport in the network, returning nil if not found.
  49. func (n *MemoryNetwork) GetTransport(id types.NodeID) *MemoryTransport {
  50. n.mtx.RLock()
  51. defer n.mtx.RUnlock()
  52. return n.transports[id]
  53. }
  54. // RemoveTransport removes a transport from the network and closes it.
  55. func (n *MemoryNetwork) RemoveTransport(id types.NodeID) {
  56. n.mtx.Lock()
  57. t, ok := n.transports[id]
  58. delete(n.transports, id)
  59. n.mtx.Unlock()
  60. if ok {
  61. // Close may recursively call RemoveTransport() again, but this is safe
  62. // because we've already removed the transport from the map above.
  63. if err := t.Close(); err != nil {
  64. n.logger.Error("failed to close memory transport", "id", id, "err", err)
  65. }
  66. }
  67. }
  68. // Size returns the number of transports in the network.
  69. func (n *MemoryNetwork) Size() int {
  70. return len(n.transports)
  71. }
  72. // MemoryTransport is an in-memory transport that uses buffered Go channels to
  73. // communicate between endpoints. It is primarily meant for testing.
  74. //
  75. // New transports are allocated with MemoryNetwork.CreateTransport(). To contact
  76. // a different endpoint, both transports must be in the same MemoryNetwork.
  77. type MemoryTransport struct {
  78. logger log.Logger
  79. network *MemoryNetwork
  80. nodeID types.NodeID
  81. bufferSize int
  82. acceptCh chan *MemoryConnection
  83. closeCh chan struct{}
  84. closeFn func()
  85. }
  86. // newMemoryTransport creates a new MemoryTransport. This is for internal use by
  87. // MemoryNetwork, use MemoryNetwork.CreateTransport() instead.
  88. func newMemoryTransport(network *MemoryNetwork, nodeID types.NodeID) *MemoryTransport {
  89. once := &sync.Once{}
  90. closeCh := make(chan struct{})
  91. return &MemoryTransport{
  92. logger: network.logger.With("local", nodeID),
  93. network: network,
  94. nodeID: nodeID,
  95. bufferSize: network.bufferSize,
  96. acceptCh: make(chan *MemoryConnection),
  97. closeCh: closeCh,
  98. closeFn: func() { once.Do(func() { close(closeCh) }) },
  99. }
  100. }
  101. // String implements Transport.
  102. func (t *MemoryTransport) String() string {
  103. return string(MemoryProtocol)
  104. }
  105. func (*MemoryTransport) Listen(Endpoint) error { return nil }
  106. func (t *MemoryTransport) AddChannelDescriptors([]*ChannelDescriptor) {}
  107. // Protocols implements Transport.
  108. func (t *MemoryTransport) Protocols() []Protocol {
  109. return []Protocol{MemoryProtocol}
  110. }
  111. // Endpoints implements Transport.
  112. func (t *MemoryTransport) Endpoints() []Endpoint {
  113. if n := t.network.GetTransport(t.nodeID); n == nil {
  114. return []Endpoint{}
  115. }
  116. return []Endpoint{{
  117. Protocol: MemoryProtocol,
  118. Path: string(t.nodeID),
  119. // An arbitrary IP and port is used in order for the pex
  120. // reactor to be able to send addresses to one another.
  121. IP: net.IPv4zero,
  122. Port: 0,
  123. }}
  124. }
  125. // Accept implements Transport.
  126. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  127. select {
  128. case <-t.closeCh:
  129. return nil, io.EOF
  130. case conn := <-t.acceptCh:
  131. t.logger.Info("accepted connection", "remote", conn.RemoteEndpoint().Path)
  132. return conn, nil
  133. case <-ctx.Done():
  134. return nil, io.EOF
  135. }
  136. }
  137. // Dial implements Transport.
  138. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  139. if endpoint.Protocol != MemoryProtocol {
  140. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  141. }
  142. if endpoint.Path == "" {
  143. return nil, errors.New("no path")
  144. }
  145. if err := endpoint.Validate(); err != nil {
  146. return nil, err
  147. }
  148. nodeID, err := types.NewNodeID(endpoint.Path)
  149. if err != nil {
  150. return nil, err
  151. }
  152. t.logger.Info("dialing peer", "remote", nodeID)
  153. peer := t.network.GetTransport(nodeID)
  154. if peer == nil {
  155. return nil, fmt.Errorf("unknown peer %q", nodeID)
  156. }
  157. inCh := make(chan memoryMessage, t.bufferSize)
  158. outCh := make(chan memoryMessage, t.bufferSize)
  159. once := &sync.Once{}
  160. closeCh := make(chan struct{})
  161. closeFn := func() { once.Do(func() { close(closeCh) }) }
  162. outConn := newMemoryConnection(t.logger, t.nodeID, peer.nodeID, inCh, outCh)
  163. outConn.closeCh = closeCh
  164. outConn.closeFn = closeFn
  165. inConn := newMemoryConnection(peer.logger, peer.nodeID, t.nodeID, outCh, inCh)
  166. inConn.closeCh = closeCh
  167. inConn.closeFn = closeFn
  168. select {
  169. case peer.acceptCh <- inConn:
  170. return outConn, nil
  171. case <-ctx.Done():
  172. return nil, io.EOF
  173. }
  174. }
  175. // Close implements Transport.
  176. func (t *MemoryTransport) Close() error {
  177. t.network.RemoveTransport(t.nodeID)
  178. t.closeFn()
  179. return nil
  180. }
  181. // MemoryConnection is an in-memory connection between two transport endpoints.
  182. type MemoryConnection struct {
  183. logger log.Logger
  184. localID types.NodeID
  185. remoteID types.NodeID
  186. receiveCh <-chan memoryMessage
  187. sendCh chan<- memoryMessage
  188. closeFn func()
  189. closeCh <-chan struct{}
  190. }
  191. // memoryMessage is passed internally, containing either a message or handshake.
  192. type memoryMessage struct {
  193. channelID ChannelID
  194. message []byte
  195. // For handshakes.
  196. nodeInfo *types.NodeInfo
  197. pubKey crypto.PubKey
  198. }
  199. // newMemoryConnection creates a new MemoryConnection.
  200. func newMemoryConnection(
  201. logger log.Logger,
  202. localID types.NodeID,
  203. remoteID types.NodeID,
  204. receiveCh <-chan memoryMessage,
  205. sendCh chan<- memoryMessage,
  206. ) *MemoryConnection {
  207. return &MemoryConnection{
  208. logger: logger.With("remote", remoteID),
  209. localID: localID,
  210. remoteID: remoteID,
  211. receiveCh: receiveCh,
  212. sendCh: sendCh,
  213. }
  214. }
  215. // String implements Connection.
  216. func (c *MemoryConnection) String() string {
  217. return c.RemoteEndpoint().String()
  218. }
  219. // LocalEndpoint implements Connection.
  220. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  221. return Endpoint{
  222. Protocol: MemoryProtocol,
  223. Path: string(c.localID),
  224. }
  225. }
  226. // RemoteEndpoint implements Connection.
  227. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  228. return Endpoint{
  229. Protocol: MemoryProtocol,
  230. Path: string(c.remoteID),
  231. }
  232. }
  233. // Handshake implements Connection.
  234. func (c *MemoryConnection) Handshake(
  235. ctx context.Context,
  236. nodeInfo types.NodeInfo,
  237. privKey crypto.PrivKey,
  238. ) (types.NodeInfo, crypto.PubKey, error) {
  239. select {
  240. case c.sendCh <- memoryMessage{nodeInfo: &nodeInfo, pubKey: privKey.PubKey()}:
  241. c.logger.Debug("sent handshake", "nodeInfo", nodeInfo)
  242. case <-c.closeCh:
  243. return types.NodeInfo{}, nil, io.EOF
  244. case <-ctx.Done():
  245. return types.NodeInfo{}, nil, ctx.Err()
  246. }
  247. select {
  248. case msg := <-c.receiveCh:
  249. if msg.nodeInfo == nil {
  250. return types.NodeInfo{}, nil, errors.New("no NodeInfo in handshake")
  251. }
  252. c.logger.Debug("received handshake", "peerInfo", msg.nodeInfo)
  253. return *msg.nodeInfo, msg.pubKey, nil
  254. case <-c.closeCh:
  255. return types.NodeInfo{}, nil, io.EOF
  256. case <-ctx.Done():
  257. return types.NodeInfo{}, nil, ctx.Err()
  258. }
  259. }
  260. // ReceiveMessage implements Connection.
  261. func (c *MemoryConnection) ReceiveMessage(ctx context.Context) (ChannelID, []byte, error) {
  262. // Check close first, since channels are buffered. Otherwise, below select
  263. // may non-deterministically return non-error even when closed.
  264. select {
  265. case <-c.closeCh:
  266. return 0, nil, io.EOF
  267. case <-ctx.Done():
  268. return 0, nil, io.EOF
  269. default:
  270. }
  271. select {
  272. case msg := <-c.receiveCh:
  273. c.logger.Debug("received message", "chID", msg.channelID, "msg", msg.message)
  274. return msg.channelID, msg.message, nil
  275. case <-ctx.Done():
  276. return 0, nil, io.EOF
  277. case <-c.closeCh:
  278. return 0, nil, io.EOF
  279. }
  280. }
  281. // SendMessage implements Connection.
  282. func (c *MemoryConnection) SendMessage(ctx context.Context, chID ChannelID, msg []byte) error {
  283. // Check close first, since channels are buffered. Otherwise, below select
  284. // may non-deterministically return non-error even when closed.
  285. select {
  286. case <-c.closeCh:
  287. return io.EOF
  288. case <-ctx.Done():
  289. return io.EOF
  290. default:
  291. }
  292. select {
  293. case c.sendCh <- memoryMessage{channelID: chID, message: msg}:
  294. c.logger.Debug("sent message", "chID", chID, "msg", msg)
  295. return nil
  296. case <-ctx.Done():
  297. return io.EOF
  298. case <-c.closeCh:
  299. return io.EOF
  300. }
  301. }
  302. // Close implements Connection.
  303. func (c *MemoryConnection) Close() error { c.closeFn(); return nil }