You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

388 lines
10 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto"
  9. "github.com/tendermint/tendermint/crypto/ed25519"
  10. "github.com/tendermint/tendermint/libs/log"
  11. "github.com/tendermint/tendermint/p2p/conn"
  12. )
  13. const (
  14. MemoryProtocol Protocol = "memory"
  15. )
  16. // MemoryNetwork is an in-memory "network" that uses Go channels to communicate
  17. // between endpoints. Transport endpoints are created with CreateTransport. It
  18. // is primarily used for testing.
  19. type MemoryNetwork struct {
  20. logger log.Logger
  21. mtx sync.RWMutex
  22. transports map[NodeID]*MemoryTransport
  23. }
  24. // NewMemoryNetwork creates a new in-memory network.
  25. func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
  26. return &MemoryNetwork{
  27. logger: logger,
  28. transports: map[NodeID]*MemoryTransport{},
  29. }
  30. }
  31. // CreateTransport creates a new memory transport and endpoint for the given
  32. // NodeInfo and private key. Use GenerateTransport() to autogenerate a random
  33. // key and node info.
  34. //
  35. // The transport immediately begins listening on the endpoint "memory:<id>", and
  36. // can be accessed by other transports in the same memory network.
  37. func (n *MemoryNetwork) CreateTransport(
  38. nodeInfo NodeInfo,
  39. privKey crypto.PrivKey,
  40. ) (*MemoryTransport, error) {
  41. nodeID := nodeInfo.NodeID
  42. if nodeID == "" {
  43. return nil, errors.New("no node ID")
  44. }
  45. t := newMemoryTransport(n, nodeInfo, privKey)
  46. n.mtx.Lock()
  47. defer n.mtx.Unlock()
  48. if _, ok := n.transports[nodeID]; ok {
  49. return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
  50. }
  51. n.transports[nodeID] = t
  52. return t, nil
  53. }
  54. // GenerateTransport generates a new transport endpoint by generating a random
  55. // private key and node info. The endpoint address can be obtained via
  56. // Transport.Endpoints().
  57. func (n *MemoryNetwork) GenerateTransport() *MemoryTransport {
  58. privKey := ed25519.GenPrivKey()
  59. nodeID := NodeIDFromPubKey(privKey.PubKey())
  60. nodeInfo := NodeInfo{
  61. NodeID: nodeID,
  62. ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID),
  63. }
  64. t, err := n.CreateTransport(nodeInfo, privKey)
  65. if err != nil {
  66. // GenerateTransport is only used for testing, and the likelihood of
  67. // generating a duplicate node ID is very low, so we'll panic.
  68. panic(err)
  69. }
  70. return t
  71. }
  72. // GetTransport looks up a transport in the network, returning nil if not found.
  73. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
  74. n.mtx.RLock()
  75. defer n.mtx.RUnlock()
  76. return n.transports[id]
  77. }
  78. // RemoveTransport removes a transport from the network and closes it.
  79. func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
  80. n.mtx.Lock()
  81. t, ok := n.transports[id]
  82. delete(n.transports, id)
  83. n.mtx.Unlock()
  84. if ok {
  85. // Close may recursively call RemoveTransport() again, but this is safe
  86. // because we've already removed the transport from the map above.
  87. return t.Close()
  88. }
  89. return nil
  90. }
  91. // MemoryTransport is an in-memory transport that's primarily meant for testing.
  92. // It communicates between endpoints using Go channels. To dial a different
  93. // endpoint, both endpoints/transports must be in the same MemoryNetwork.
  94. type MemoryTransport struct {
  95. network *MemoryNetwork
  96. nodeInfo NodeInfo
  97. privKey crypto.PrivKey
  98. logger log.Logger
  99. acceptCh chan *MemoryConnection
  100. closeCh chan struct{}
  101. closeOnce sync.Once
  102. }
  103. // newMemoryTransport creates a new in-memory transport in the given network.
  104. // Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
  105. // to create transports, this is for internal use by MemoryNetwork.
  106. func newMemoryTransport(
  107. network *MemoryNetwork,
  108. nodeInfo NodeInfo,
  109. privKey crypto.PrivKey,
  110. ) *MemoryTransport {
  111. return &MemoryTransport{
  112. network: network,
  113. nodeInfo: nodeInfo,
  114. privKey: privKey,
  115. logger: network.logger.With("local",
  116. fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.NodeID)),
  117. acceptCh: make(chan *MemoryConnection),
  118. closeCh: make(chan struct{}),
  119. }
  120. }
  121. // Accept implements Transport.
  122. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  123. select {
  124. case conn := <-t.acceptCh:
  125. t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
  126. return conn, nil
  127. case <-t.closeCh:
  128. return nil, ErrTransportClosed{}
  129. case <-ctx.Done():
  130. return nil, ctx.Err()
  131. }
  132. }
  133. // Dial implements Transport.
  134. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  135. if endpoint.Protocol != MemoryProtocol {
  136. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  137. }
  138. if endpoint.PeerID == "" {
  139. return nil, errors.New("no peer ID")
  140. }
  141. t.logger.Info("dialing peer", "remote", endpoint)
  142. peerTransport := t.network.GetTransport(endpoint.PeerID)
  143. if peerTransport == nil {
  144. return nil, fmt.Errorf("unknown peer %q", endpoint.PeerID)
  145. }
  146. inCh := make(chan memoryMessage, 1)
  147. outCh := make(chan memoryMessage, 1)
  148. closeCh := make(chan struct{})
  149. closeOnce := sync.Once{}
  150. closer := func() bool {
  151. closed := false
  152. closeOnce.Do(func() {
  153. close(closeCh)
  154. closed = true
  155. })
  156. return closed
  157. }
  158. outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer)
  159. inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer)
  160. select {
  161. case peerTransport.acceptCh <- inConn:
  162. return outConn, nil
  163. case <-peerTransport.closeCh:
  164. return nil, ErrTransportClosed{}
  165. case <-ctx.Done():
  166. return nil, ctx.Err()
  167. }
  168. }
  169. // DialAccept is a convenience function that dials a peer MemoryTransport and
  170. // returns both ends of the connection (A to B and B to A).
  171. func (t *MemoryTransport) DialAccept(
  172. ctx context.Context,
  173. peer *MemoryTransport,
  174. ) (Connection, Connection, error) {
  175. endpoints := peer.Endpoints()
  176. if len(endpoints) == 0 {
  177. return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.NodeID)
  178. }
  179. acceptCh := make(chan Connection, 1)
  180. errCh := make(chan error, 1)
  181. go func() {
  182. conn, err := peer.Accept(ctx)
  183. errCh <- err
  184. acceptCh <- conn
  185. }()
  186. outConn, err := t.Dial(ctx, endpoints[0])
  187. if err != nil {
  188. return nil, nil, err
  189. }
  190. if err = <-errCh; err != nil {
  191. return nil, nil, err
  192. }
  193. inConn := <-acceptCh
  194. return outConn, inConn, nil
  195. }
  196. // Close implements Transport.
  197. func (t *MemoryTransport) Close() error {
  198. err := t.network.RemoveTransport(t.nodeInfo.NodeID)
  199. t.closeOnce.Do(func() {
  200. close(t.closeCh)
  201. })
  202. t.logger.Info("stopped accepting connections")
  203. return err
  204. }
  205. // Endpoints implements Transport.
  206. func (t *MemoryTransport) Endpoints() []Endpoint {
  207. select {
  208. case <-t.closeCh:
  209. return []Endpoint{}
  210. default:
  211. return []Endpoint{{
  212. Protocol: MemoryProtocol,
  213. PeerID: t.nodeInfo.NodeID,
  214. }}
  215. }
  216. }
  217. // SetChannelDescriptors implements Transport.
  218. func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) {
  219. }
  220. // MemoryConnection is an in-memory connection between two transports (nodes).
  221. type MemoryConnection struct {
  222. logger log.Logger
  223. local *MemoryTransport
  224. remote *MemoryTransport
  225. receiveCh <-chan memoryMessage
  226. sendCh chan<- memoryMessage
  227. closeCh <-chan struct{}
  228. close func() bool
  229. }
  230. // memoryMessage is used to pass messages internally in the connection.
  231. type memoryMessage struct {
  232. channel byte
  233. message []byte
  234. }
  235. // newMemoryConnection creates a new MemoryConnection. It takes all channels
  236. // (including the closeCh signal channel) on construction, such that they can be
  237. // shared between both ends of the connection.
  238. func newMemoryConnection(
  239. local *MemoryTransport,
  240. remote *MemoryTransport,
  241. receiveCh <-chan memoryMessage,
  242. sendCh chan<- memoryMessage,
  243. closeCh <-chan struct{},
  244. close func() bool,
  245. ) *MemoryConnection {
  246. c := &MemoryConnection{
  247. local: local,
  248. remote: remote,
  249. receiveCh: receiveCh,
  250. sendCh: sendCh,
  251. closeCh: closeCh,
  252. close: close,
  253. }
  254. c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
  255. return c
  256. }
  257. // ReceiveMessage implements Connection.
  258. func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
  259. // check close first, since channels are buffered
  260. select {
  261. case <-c.closeCh:
  262. return 0, nil, io.EOF
  263. default:
  264. }
  265. select {
  266. case msg := <-c.receiveCh:
  267. c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
  268. return msg.channel, msg.message, nil
  269. case <-c.closeCh:
  270. return 0, nil, io.EOF
  271. }
  272. }
  273. // SendMessage implements Connection.
  274. func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
  275. // check close first, since channels are buffered
  276. select {
  277. case <-c.closeCh:
  278. return false, io.EOF
  279. default:
  280. }
  281. select {
  282. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  283. c.logger.Debug("sent message", "channel", chID, "message", msg)
  284. return true, nil
  285. case <-c.closeCh:
  286. return false, io.EOF
  287. }
  288. }
  289. // TrySendMessage implements Connection.
  290. func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
  291. // check close first, since channels are buffered
  292. select {
  293. case <-c.closeCh:
  294. return false, io.EOF
  295. default:
  296. }
  297. select {
  298. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  299. c.logger.Debug("sent message", "channel", chID, "message", msg)
  300. return true, nil
  301. case <-c.closeCh:
  302. return false, io.EOF
  303. default:
  304. return false, nil
  305. }
  306. }
  307. // Close closes the connection.
  308. func (c *MemoryConnection) Close() error {
  309. if c.close() {
  310. c.logger.Info("closed connection")
  311. }
  312. return nil
  313. }
  314. // FlushClose flushes all pending sends and then closes the connection.
  315. func (c *MemoryConnection) FlushClose() error {
  316. return c.Close()
  317. }
  318. // LocalEndpoint returns the local endpoint for the connection.
  319. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  320. return Endpoint{
  321. PeerID: c.local.nodeInfo.NodeID,
  322. Protocol: MemoryProtocol,
  323. }
  324. }
  325. // RemoteEndpoint returns the remote endpoint for the connection.
  326. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  327. return Endpoint{
  328. PeerID: c.remote.nodeInfo.NodeID,
  329. Protocol: MemoryProtocol,
  330. }
  331. }
  332. // PubKey returns the remote peer's public key.
  333. func (c *MemoryConnection) PubKey() crypto.PubKey {
  334. return c.remote.privKey.PubKey()
  335. }
  336. // NodeInfo returns the remote peer's node info.
  337. func (c *MemoryConnection) NodeInfo() NodeInfo {
  338. return c.remote.nodeInfo
  339. }
  340. // Status returns the current connection status.
  341. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  342. return conn.ConnectionStatus{}
  343. }