You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

394 lines
10 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto"
  9. "github.com/tendermint/tendermint/crypto/ed25519"
  10. "github.com/tendermint/tendermint/libs/log"
  11. "github.com/tendermint/tendermint/p2p/conn"
  12. )
  13. const (
  14. MemoryProtocol Protocol = "memory"
  15. )
  16. // MemoryNetwork is an in-memory "network" that uses Go channels to communicate
  17. // between endpoints. Transport endpoints are created with CreateTransport. It
  18. // is primarily used for testing.
  19. type MemoryNetwork struct {
  20. logger log.Logger
  21. mtx sync.RWMutex
  22. transports map[NodeID]*MemoryTransport
  23. }
  24. // NewMemoryNetwork creates a new in-memory network.
  25. func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
  26. return &MemoryNetwork{
  27. logger: logger,
  28. transports: map[NodeID]*MemoryTransport{},
  29. }
  30. }
  31. // CreateTransport creates a new memory transport and endpoint for the given
  32. // NodeInfo and private key. Use GenerateTransport() to autogenerate a random
  33. // key and node info.
  34. //
  35. // The transport immediately begins listening on the endpoint "memory:<id>", and
  36. // can be accessed by other transports in the same memory network.
  37. func (n *MemoryNetwork) CreateTransport(
  38. nodeInfo NodeInfo,
  39. privKey crypto.PrivKey,
  40. ) (*MemoryTransport, error) {
  41. nodeID := nodeInfo.NodeID
  42. if nodeID == "" {
  43. return nil, errors.New("no node ID")
  44. }
  45. t := newMemoryTransport(n, nodeInfo, privKey)
  46. n.mtx.Lock()
  47. defer n.mtx.Unlock()
  48. if _, ok := n.transports[nodeID]; ok {
  49. return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
  50. }
  51. n.transports[nodeID] = t
  52. return t, nil
  53. }
  54. // GenerateTransport generates a new transport endpoint by generating a random
  55. // private key and node info. The endpoint address can be obtained via
  56. // Transport.Endpoints().
  57. func (n *MemoryNetwork) GenerateTransport() *MemoryTransport {
  58. privKey := ed25519.GenPrivKey()
  59. nodeID := NodeIDFromPubKey(privKey.PubKey())
  60. nodeInfo := NodeInfo{
  61. NodeID: nodeID,
  62. ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID),
  63. }
  64. t, err := n.CreateTransport(nodeInfo, privKey)
  65. if err != nil {
  66. // GenerateTransport is only used for testing, and the likelihood of
  67. // generating a duplicate node ID is very low, so we'll panic.
  68. panic(err)
  69. }
  70. return t
  71. }
  72. // GetTransport looks up a transport in the network, returning nil if not found.
  73. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
  74. n.mtx.RLock()
  75. defer n.mtx.RUnlock()
  76. return n.transports[id]
  77. }
  78. // RemoveTransport removes a transport from the network and closes it.
  79. func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
  80. n.mtx.Lock()
  81. t, ok := n.transports[id]
  82. delete(n.transports, id)
  83. n.mtx.Unlock()
  84. if ok {
  85. // Close may recursively call RemoveTransport() again, but this is safe
  86. // because we've already removed the transport from the map above.
  87. return t.Close()
  88. }
  89. return nil
  90. }
  91. // MemoryTransport is an in-memory transport that's primarily meant for testing.
  92. // It communicates between endpoints using Go channels. To dial a different
  93. // endpoint, both endpoints/transports must be in the same MemoryNetwork.
  94. type MemoryTransport struct {
  95. network *MemoryNetwork
  96. nodeInfo NodeInfo
  97. privKey crypto.PrivKey
  98. logger log.Logger
  99. acceptCh chan *MemoryConnection
  100. closeCh chan struct{}
  101. closeOnce sync.Once
  102. }
  103. // newMemoryTransport creates a new in-memory transport in the given network.
  104. // Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
  105. // to create transports, this is for internal use by MemoryNetwork.
  106. func newMemoryTransport(
  107. network *MemoryNetwork,
  108. nodeInfo NodeInfo,
  109. privKey crypto.PrivKey,
  110. ) *MemoryTransport {
  111. return &MemoryTransport{
  112. network: network,
  113. nodeInfo: nodeInfo,
  114. privKey: privKey,
  115. logger: network.logger.With("local",
  116. fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.NodeID)),
  117. acceptCh: make(chan *MemoryConnection),
  118. closeCh: make(chan struct{}),
  119. }
  120. }
  121. // Accept implements Transport.
  122. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  123. select {
  124. case conn := <-t.acceptCh:
  125. t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
  126. return conn, nil
  127. case <-t.closeCh:
  128. return nil, ErrTransportClosed{}
  129. case <-ctx.Done():
  130. return nil, ctx.Err()
  131. }
  132. }
  133. // Dial implements Transport.
  134. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  135. if endpoint.Protocol != MemoryProtocol {
  136. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  137. }
  138. if endpoint.Path == "" {
  139. return nil, errors.New("no path")
  140. }
  141. if endpoint.PeerID == "" {
  142. return nil, errors.New("no peer ID")
  143. }
  144. t.logger.Info("dialing peer", "remote", endpoint)
  145. peerTransport := t.network.GetTransport(NodeID(endpoint.Path))
  146. if peerTransport == nil {
  147. return nil, fmt.Errorf("unknown peer %q", endpoint.Path)
  148. }
  149. inCh := make(chan memoryMessage, 1)
  150. outCh := make(chan memoryMessage, 1)
  151. closeCh := make(chan struct{})
  152. closeOnce := sync.Once{}
  153. closer := func() bool {
  154. closed := false
  155. closeOnce.Do(func() {
  156. close(closeCh)
  157. closed = true
  158. })
  159. return closed
  160. }
  161. outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer)
  162. inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer)
  163. select {
  164. case peerTransport.acceptCh <- inConn:
  165. return outConn, nil
  166. case <-peerTransport.closeCh:
  167. return nil, ErrTransportClosed{}
  168. case <-ctx.Done():
  169. return nil, ctx.Err()
  170. }
  171. }
  172. // DialAccept is a convenience function that dials a peer MemoryTransport and
  173. // returns both ends of the connection (A to B and B to A).
  174. func (t *MemoryTransport) DialAccept(
  175. ctx context.Context,
  176. peer *MemoryTransport,
  177. ) (Connection, Connection, error) {
  178. endpoints := peer.Endpoints()
  179. if len(endpoints) == 0 {
  180. return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.NodeID)
  181. }
  182. acceptCh := make(chan Connection, 1)
  183. errCh := make(chan error, 1)
  184. go func() {
  185. conn, err := peer.Accept(ctx)
  186. errCh <- err
  187. acceptCh <- conn
  188. }()
  189. outConn, err := t.Dial(ctx, endpoints[0])
  190. if err != nil {
  191. return nil, nil, err
  192. }
  193. if err = <-errCh; err != nil {
  194. return nil, nil, err
  195. }
  196. inConn := <-acceptCh
  197. return outConn, inConn, nil
  198. }
  199. // Close implements Transport.
  200. func (t *MemoryTransport) Close() error {
  201. err := t.network.RemoveTransport(t.nodeInfo.NodeID)
  202. t.closeOnce.Do(func() {
  203. close(t.closeCh)
  204. })
  205. t.logger.Info("stopped accepting connections")
  206. return err
  207. }
  208. // Endpoints implements Transport.
  209. func (t *MemoryTransport) Endpoints() []Endpoint {
  210. select {
  211. case <-t.closeCh:
  212. return []Endpoint{}
  213. default:
  214. return []Endpoint{{
  215. Protocol: MemoryProtocol,
  216. PeerID: t.nodeInfo.NodeID,
  217. Path: string(t.nodeInfo.NodeID),
  218. }}
  219. }
  220. }
  221. // SetChannelDescriptors implements Transport.
  222. func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) {
  223. }
  224. // MemoryConnection is an in-memory connection between two transports (nodes).
  225. type MemoryConnection struct {
  226. logger log.Logger
  227. local *MemoryTransport
  228. remote *MemoryTransport
  229. receiveCh <-chan memoryMessage
  230. sendCh chan<- memoryMessage
  231. closeCh <-chan struct{}
  232. close func() bool
  233. }
  // memoryMessage is the unit passed over a MemoryConnection's channels: a
// channel ID plus the raw message bytes.
type memoryMessage struct {
	channel byte   // channel ID the message belongs to
	message []byte // message payload, passed by reference (not copied)
}
  239. // newMemoryConnection creates a new MemoryConnection. It takes all channels
  240. // (including the closeCh signal channel) on construction, such that they can be
  241. // shared between both ends of the connection.
  242. func newMemoryConnection(
  243. local *MemoryTransport,
  244. remote *MemoryTransport,
  245. receiveCh <-chan memoryMessage,
  246. sendCh chan<- memoryMessage,
  247. closeCh <-chan struct{},
  248. close func() bool,
  249. ) *MemoryConnection {
  250. c := &MemoryConnection{
  251. local: local,
  252. remote: remote,
  253. receiveCh: receiveCh,
  254. sendCh: sendCh,
  255. closeCh: closeCh,
  256. close: close,
  257. }
  258. c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
  259. return c
  260. }
  261. // ReceiveMessage implements Connection.
  262. func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
  263. // check close first, since channels are buffered
  264. select {
  265. case <-c.closeCh:
  266. return 0, nil, io.EOF
  267. default:
  268. }
  269. select {
  270. case msg := <-c.receiveCh:
  271. c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
  272. return msg.channel, msg.message, nil
  273. case <-c.closeCh:
  274. return 0, nil, io.EOF
  275. }
  276. }
  277. // SendMessage implements Connection.
  278. func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
  279. // check close first, since channels are buffered
  280. select {
  281. case <-c.closeCh:
  282. return false, io.EOF
  283. default:
  284. }
  285. select {
  286. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  287. c.logger.Debug("sent message", "channel", chID, "message", msg)
  288. return true, nil
  289. case <-c.closeCh:
  290. return false, io.EOF
  291. }
  292. }
  293. // TrySendMessage implements Connection.
  294. func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
  295. // check close first, since channels are buffered
  296. select {
  297. case <-c.closeCh:
  298. return false, io.EOF
  299. default:
  300. }
  301. select {
  302. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  303. c.logger.Debug("sent message", "channel", chID, "message", msg)
  304. return true, nil
  305. case <-c.closeCh:
  306. return false, io.EOF
  307. default:
  308. return false, nil
  309. }
  310. }
  311. // Close closes the connection.
  312. func (c *MemoryConnection) Close() error {
  313. if c.close() {
  314. c.logger.Info("closed connection")
  315. }
  316. return nil
  317. }
  318. // FlushClose flushes all pending sends and then closes the connection.
  319. func (c *MemoryConnection) FlushClose() error {
  320. return c.Close()
  321. }
  322. // LocalEndpoint returns the local endpoint for the connection.
  323. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  324. return Endpoint{
  325. PeerID: c.local.nodeInfo.NodeID,
  326. Protocol: MemoryProtocol,
  327. Path: string(c.local.nodeInfo.NodeID),
  328. }
  329. }
  330. // RemoteEndpoint returns the remote endpoint for the connection.
  331. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  332. return Endpoint{
  333. PeerID: c.remote.nodeInfo.NodeID,
  334. Protocol: MemoryProtocol,
  335. Path: string(c.remote.nodeInfo.NodeID),
  336. }
  337. }
  338. // PubKey returns the remote peer's public key.
  339. func (c *MemoryConnection) PubKey() crypto.PubKey {
  340. return c.remote.privKey.PubKey()
  341. }
  342. // NodeInfo returns the remote peer's node info.
  343. func (c *MemoryConnection) NodeInfo() NodeInfo {
  344. return c.remote.nodeInfo
  345. }
  346. // Status returns the current connection status.
  347. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  348. return conn.ConnectionStatus{}
  349. }