You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

398 lines
10 KiB

  1. package p2p
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "sync"
  8. "github.com/tendermint/tendermint/crypto"
  9. "github.com/tendermint/tendermint/crypto/ed25519"
  10. "github.com/tendermint/tendermint/libs/log"
  11. "github.com/tendermint/tendermint/p2p/conn"
  12. )
  13. const (
  14. MemoryProtocol Protocol = "memory"
  15. )
  16. // MemoryNetwork is an in-memory "network" that uses Go channels to communicate
  17. // between endpoints. Transport endpoints are created with CreateTransport. It
  18. // is primarily used for testing.
  19. type MemoryNetwork struct {
  20. logger log.Logger
  21. mtx sync.RWMutex
  22. transports map[NodeID]*MemoryTransport
  23. }
  24. // NewMemoryNetwork creates a new in-memory network.
  25. func NewMemoryNetwork(logger log.Logger) *MemoryNetwork {
  26. return &MemoryNetwork{
  27. logger: logger,
  28. transports: map[NodeID]*MemoryTransport{},
  29. }
  30. }
  31. // CreateTransport creates a new memory transport and endpoint for the given
  32. // NodeInfo and private key. Use GenerateTransport() to autogenerate a random
  33. // key and node info.
  34. //
  35. // The transport immediately begins listening on the endpoint "memory:<id>", and
  36. // can be accessed by other transports in the same memory network.
  37. func (n *MemoryNetwork) CreateTransport(
  38. nodeInfo NodeInfo,
  39. privKey crypto.PrivKey,
  40. ) (*MemoryTransport, error) {
  41. nodeID := nodeInfo.NodeID
  42. if nodeID == "" {
  43. return nil, errors.New("no node ID")
  44. }
  45. t := newMemoryTransport(n, nodeInfo, privKey)
  46. n.mtx.Lock()
  47. defer n.mtx.Unlock()
  48. if _, ok := n.transports[nodeID]; ok {
  49. return nil, fmt.Errorf("transport with node ID %q already exists", nodeID)
  50. }
  51. n.transports[nodeID] = t
  52. return t, nil
  53. }
  54. // GenerateTransport generates a new transport endpoint by generating a random
  55. // private key and node info. The endpoint address can be obtained via
  56. // Transport.Endpoints().
  57. func (n *MemoryNetwork) GenerateTransport() *MemoryTransport {
  58. privKey := ed25519.GenPrivKey()
  59. nodeID := NodeIDFromPubKey(privKey.PubKey())
  60. nodeInfo := NodeInfo{
  61. NodeID: nodeID,
  62. ListenAddr: fmt.Sprintf("%v:%v", MemoryProtocol, nodeID),
  63. }
  64. t, err := n.CreateTransport(nodeInfo, privKey)
  65. if err != nil {
  66. // GenerateTransport is only used for testing, and the likelihood of
  67. // generating a duplicate node ID is very low, so we'll panic.
  68. panic(err)
  69. }
  70. return t
  71. }
  72. // GetTransport looks up a transport in the network, returning nil if not found.
  73. func (n *MemoryNetwork) GetTransport(id NodeID) *MemoryTransport {
  74. n.mtx.RLock()
  75. defer n.mtx.RUnlock()
  76. return n.transports[id]
  77. }
  78. // RemoveTransport removes a transport from the network and closes it.
  79. func (n *MemoryNetwork) RemoveTransport(id NodeID) error {
  80. n.mtx.Lock()
  81. t, ok := n.transports[id]
  82. delete(n.transports, id)
  83. n.mtx.Unlock()
  84. if ok {
  85. // Close may recursively call RemoveTransport() again, but this is safe
  86. // because we've already removed the transport from the map above.
  87. return t.Close()
  88. }
  89. return nil
  90. }
  91. // MemoryTransport is an in-memory transport that's primarily meant for testing.
  92. // It communicates between endpoints using Go channels. To dial a different
  93. // endpoint, both endpoints/transports must be in the same MemoryNetwork.
  94. type MemoryTransport struct {
  95. network *MemoryNetwork
  96. nodeInfo NodeInfo
  97. privKey crypto.PrivKey
  98. logger log.Logger
  99. acceptCh chan *MemoryConnection
  100. closeCh chan struct{}
  101. closeOnce sync.Once
  102. }
  103. // newMemoryTransport creates a new in-memory transport in the given network.
  104. // Callers should use MemoryNetwork.CreateTransport() or GenerateTransport()
  105. // to create transports, this is for internal use by MemoryNetwork.
  106. func newMemoryTransport(
  107. network *MemoryNetwork,
  108. nodeInfo NodeInfo,
  109. privKey crypto.PrivKey,
  110. ) *MemoryTransport {
  111. return &MemoryTransport{
  112. network: network,
  113. nodeInfo: nodeInfo,
  114. privKey: privKey,
  115. logger: network.logger.With("local",
  116. fmt.Sprintf("%v:%v", MemoryProtocol, nodeInfo.NodeID)),
  117. acceptCh: make(chan *MemoryConnection),
  118. closeCh: make(chan struct{}),
  119. }
  120. }
  121. // String displays the transport.
  122. //
  123. // FIXME: The Transport interface should either have Name() or embed
  124. // fmt.Stringer. This is necessary since we log the transport (to know which one
  125. // it is), and if it doesn't implement fmt.Stringer then it inspects all struct
  126. // contents via reflect, which triggers the race detector.
  127. func (t *MemoryTransport) String() string {
  128. return "memory"
  129. }
  130. // Accept implements Transport.
  131. func (t *MemoryTransport) Accept(ctx context.Context) (Connection, error) {
  132. select {
  133. case conn := <-t.acceptCh:
  134. t.logger.Info("accepted connection from peer", "remote", conn.RemoteEndpoint())
  135. return conn, nil
  136. case <-t.closeCh:
  137. return nil, ErrTransportClosed{}
  138. case <-ctx.Done():
  139. return nil, ctx.Err()
  140. }
  141. }
  142. // Dial implements Transport.
  143. func (t *MemoryTransport) Dial(ctx context.Context, endpoint Endpoint) (Connection, error) {
  144. if endpoint.Protocol != MemoryProtocol {
  145. return nil, fmt.Errorf("invalid protocol %q", endpoint.Protocol)
  146. }
  147. if endpoint.PeerID == "" {
  148. return nil, errors.New("no peer ID")
  149. }
  150. t.logger.Info("dialing peer", "remote", endpoint)
  151. peerTransport := t.network.GetTransport(endpoint.PeerID)
  152. if peerTransport == nil {
  153. return nil, fmt.Errorf("unknown peer %q", endpoint.PeerID)
  154. }
  155. inCh := make(chan memoryMessage, 1)
  156. outCh := make(chan memoryMessage, 1)
  157. closeCh := make(chan struct{})
  158. closeOnce := sync.Once{}
  159. closer := func() bool {
  160. closed := false
  161. closeOnce.Do(func() {
  162. close(closeCh)
  163. closed = true
  164. })
  165. return closed
  166. }
  167. outConn := newMemoryConnection(t, peerTransport, inCh, outCh, closeCh, closer)
  168. inConn := newMemoryConnection(peerTransport, t, outCh, inCh, closeCh, closer)
  169. select {
  170. case peerTransport.acceptCh <- inConn:
  171. return outConn, nil
  172. case <-peerTransport.closeCh:
  173. return nil, ErrTransportClosed{}
  174. case <-ctx.Done():
  175. return nil, ctx.Err()
  176. }
  177. }
  178. // DialAccept is a convenience function that dials a peer MemoryTransport and
  179. // returns both ends of the connection (A to B and B to A).
  180. func (t *MemoryTransport) DialAccept(
  181. ctx context.Context,
  182. peer *MemoryTransport,
  183. ) (Connection, Connection, error) {
  184. endpoints := peer.Endpoints()
  185. if len(endpoints) == 0 {
  186. return nil, nil, fmt.Errorf("peer %q not listening on any endpoints", peer.nodeInfo.NodeID)
  187. }
  188. acceptCh := make(chan Connection, 1)
  189. errCh := make(chan error, 1)
  190. go func() {
  191. conn, err := peer.Accept(ctx)
  192. errCh <- err
  193. acceptCh <- conn
  194. }()
  195. outConn, err := t.Dial(ctx, endpoints[0])
  196. if err != nil {
  197. return nil, nil, err
  198. }
  199. if err = <-errCh; err != nil {
  200. return nil, nil, err
  201. }
  202. inConn := <-acceptCh
  203. return outConn, inConn, nil
  204. }
  205. // Close implements Transport.
  206. func (t *MemoryTransport) Close() error {
  207. err := t.network.RemoveTransport(t.nodeInfo.NodeID)
  208. t.closeOnce.Do(func() {
  209. close(t.closeCh)
  210. })
  211. t.logger.Info("stopped accepting connections")
  212. return err
  213. }
  214. // Endpoints implements Transport.
  215. func (t *MemoryTransport) Endpoints() []Endpoint {
  216. select {
  217. case <-t.closeCh:
  218. return []Endpoint{}
  219. default:
  220. return []Endpoint{{
  221. Protocol: MemoryProtocol,
  222. PeerID: t.nodeInfo.NodeID,
  223. }}
  224. }
  225. }
  226. // SetChannelDescriptors implements Transport.
  227. func (t *MemoryTransport) SetChannelDescriptors(chDescs []*conn.ChannelDescriptor) {
  228. }
  229. // MemoryConnection is an in-memory connection between two transports (nodes).
  230. type MemoryConnection struct {
  231. logger log.Logger
  232. local *MemoryTransport
  233. remote *MemoryTransport
  234. receiveCh <-chan memoryMessage
  235. sendCh chan<- memoryMessage
  236. closeCh <-chan struct{}
  237. close func() bool
  238. }
  // memoryMessage is the unit passed between the two ends of a
  // MemoryConnection: a payload tagged with its channel ID.
  type memoryMessage struct {
  	channel byte   // channel ID the message was sent on
  	message []byte // raw payload
  }
  244. // newMemoryConnection creates a new MemoryConnection. It takes all channels
  245. // (including the closeCh signal channel) on construction, such that they can be
  246. // shared between both ends of the connection.
  247. func newMemoryConnection(
  248. local *MemoryTransport,
  249. remote *MemoryTransport,
  250. receiveCh <-chan memoryMessage,
  251. sendCh chan<- memoryMessage,
  252. closeCh <-chan struct{},
  253. close func() bool,
  254. ) *MemoryConnection {
  255. c := &MemoryConnection{
  256. local: local,
  257. remote: remote,
  258. receiveCh: receiveCh,
  259. sendCh: sendCh,
  260. closeCh: closeCh,
  261. close: close,
  262. }
  263. c.logger = c.local.logger.With("remote", c.RemoteEndpoint())
  264. return c
  265. }
  266. // ReceiveMessage implements Connection.
  267. func (c *MemoryConnection) ReceiveMessage() (chID byte, msg []byte, err error) {
  268. // check close first, since channels are buffered
  269. select {
  270. case <-c.closeCh:
  271. return 0, nil, io.EOF
  272. default:
  273. }
  274. select {
  275. case msg := <-c.receiveCh:
  276. c.logger.Debug("received message", "channel", msg.channel, "message", msg.message)
  277. return msg.channel, msg.message, nil
  278. case <-c.closeCh:
  279. return 0, nil, io.EOF
  280. }
  281. }
  282. // SendMessage implements Connection.
  283. func (c *MemoryConnection) SendMessage(chID byte, msg []byte) (bool, error) {
  284. // check close first, since channels are buffered
  285. select {
  286. case <-c.closeCh:
  287. return false, io.EOF
  288. default:
  289. }
  290. select {
  291. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  292. c.logger.Debug("sent message", "channel", chID, "message", msg)
  293. return true, nil
  294. case <-c.closeCh:
  295. return false, io.EOF
  296. }
  297. }
  298. // TrySendMessage implements Connection.
  299. func (c *MemoryConnection) TrySendMessage(chID byte, msg []byte) (bool, error) {
  300. // check close first, since channels are buffered
  301. select {
  302. case <-c.closeCh:
  303. return false, io.EOF
  304. default:
  305. }
  306. select {
  307. case c.sendCh <- memoryMessage{channel: chID, message: msg}:
  308. c.logger.Debug("sent message", "channel", chID, "message", msg)
  309. return true, nil
  310. case <-c.closeCh:
  311. return false, io.EOF
  312. default:
  313. return false, nil
  314. }
  315. }
  316. // Close closes the connection.
  317. func (c *MemoryConnection) Close() error {
  318. if c.close() {
  319. c.logger.Info("closed connection")
  320. }
  321. return nil
  322. }
  323. // FlushClose flushes all pending sends and then closes the connection.
  324. func (c *MemoryConnection) FlushClose() error {
  325. return c.Close()
  326. }
  327. // LocalEndpoint returns the local endpoint for the connection.
  328. func (c *MemoryConnection) LocalEndpoint() Endpoint {
  329. return Endpoint{
  330. PeerID: c.local.nodeInfo.NodeID,
  331. Protocol: MemoryProtocol,
  332. }
  333. }
  334. // RemoteEndpoint returns the remote endpoint for the connection.
  335. func (c *MemoryConnection) RemoteEndpoint() Endpoint {
  336. return Endpoint{
  337. PeerID: c.remote.nodeInfo.NodeID,
  338. Protocol: MemoryProtocol,
  339. }
  340. }
  341. // PubKey returns the remote peer's public key.
  342. func (c *MemoryConnection) PubKey() crypto.PubKey {
  343. return c.remote.privKey.PubKey()
  344. }
  345. // NodeInfo returns the remote peer's node info.
  346. func (c *MemoryConnection) NodeInfo() NodeInfo {
  347. return c.remote.nodeInfo
  348. }
  349. // Status returns the current connection status.
  350. func (c *MemoryConnection) Status() conn.ConnectionStatus {
  351. return conn.ConnectionStatus{}
  352. }