From a879eb444d0eca9d803881c80785555d312c3946 Mon Sep 17 00:00:00 2001 From: Aleksandr Bezobchuk Date: Wed, 9 Dec 2020 09:31:06 -0500 Subject: [PATCH] p2p: state sync reactor refactor (#5671) --- node/node.go | 31 +- p2p/channel.go | 120 +++++ p2p/peer.go | 131 +++++ p2p/peer_test.go | 45 ++ p2p/shim.go | 342 +++++++++++++ p2p/shim_test.go | 207 ++++++++ proto/tendermint/statesync/message.go | 94 ++++ proto/tendermint/statesync/message_test.go | 86 ++++ statesync/chunks.go | 58 ++- statesync/chunks_test.go | 41 +- statesync/messages.go | 87 +--- statesync/messages_test.go | 79 +-- statesync/reactor.go | 527 +++++++++++++++------ statesync/reactor_test.go | 244 +++++++--- statesync/snapshots.go | 65 +-- statesync/snapshots_test.go | 188 ++++---- statesync/syncer.go | 87 ++-- statesync/syncer_test.go | 387 ++++++++------- test/maverick/node/node.go | 31 +- 19 files changed, 2113 insertions(+), 737 deletions(-) create mode 100644 p2p/channel.go create mode 100644 p2p/shim.go create mode 100644 p2p/shim_test.go create mode 100644 proto/tendermint/statesync/message.go create mode 100644 proto/tendermint/statesync/message_test.go diff --git a/node/node.go b/node/node.go index fdd6debdc..81bf67490 100644 --- a/node/node.go +++ b/node/node.go @@ -483,7 +483,7 @@ func createSwitch(config *cfg.Config, peerFilters []p2p.PeerFilterFunc, mempoolReactor *mempl.Reactor, bcReactor p2p.Reactor, - stateSyncReactor *statesync.Reactor, + stateSyncReactor *p2p.ReactorShim, consensusReactor *cs.Reactor, evidenceReactor *evidence.Reactor, nodeInfo p2p.NodeInfo, @@ -746,9 +746,18 @@ func NewNode(config *cfg.Config, // FIXME The way we do phased startups (e.g. replay -> fast sync -> consensus) is very messy, // we should clean this whole thing up. See: // https://github.com/tendermint/tendermint/issues/4644 - stateSyncReactor := statesync.NewReactor(proxyApp.Snapshot(), proxyApp.Query(), - config.StateSync.TempDir) - stateSyncReactor.SetLogger(logger.With("module", "statesync")) + stateSyncReactorShim := p2p.NewReactorShim("StateSyncShim", statesync.ChannelShims) + stateSyncReactorShim.SetLogger(logger.With("module", "statesync")) + + stateSyncReactor := statesync.NewReactor( + stateSyncReactorShim.Logger, + proxyApp.Snapshot(), + proxyApp.Query(), + stateSyncReactorShim.GetChannel(statesync.SnapshotChannel), + stateSyncReactorShim.GetChannel(statesync.ChunkChannel), + stateSyncReactorShim.PeerUpdates, + config.StateSync.TempDir, + ) nodeInfo, err := makeNodeInfo(config, nodeKey, txIndexer, genDoc, state) if err != nil { @@ -762,7 +771,7 @@ func NewNode(config *cfg.Config, p2pLogger := logger.With("module", "p2p") sw := createSwitch( config, transport, p2pMetrics, peerFilters, mempoolReactor, bcReactor, - stateSyncReactor, consensusReactor, evidenceReactor, nodeInfo, nodeKey, p2pLogger, + stateSyncReactorShim, consensusReactor, evidenceReactor, nodeInfo, nodeKey, p2pLogger, ) err = sw.AddPersistentPeers(splitAndTrimEmpty(config.P2P.PersistentPeers, ",", " ")) @@ -892,6 +901,11 @@ func (n *Node) OnStart() error { return err } + // Start the real state sync reactor separately since the switch uses the shim. + if err := n.stateSyncReactor.Start(); err != nil { + return err + } + // Always connect to persistent peers err = n.sw.DialPeersAsync(splitAndTrimEmpty(n.config.P2P.PersistentPeers, ",", " ")) if err != nil { @@ -933,6 +947,11 @@ func (n *Node) OnStop() { n.Logger.Error("Error closing switch", "err", err) } + // Stop the real state sync reactor separately since the switch uses the shim. 
+ if err := n.stateSyncReactor.Stop(); err != nil { + n.Logger.Error("failed to stop state sync service", "err", err) + } + // stop mempool WAL if n.config.Mempool.WalEnabled() { n.mempool.CloseWAL() @@ -1255,7 +1274,7 @@ func makeNodeInfo( cs.StateChannel, cs.DataChannel, cs.VoteChannel, cs.VoteSetBitsChannel, mempl.MempoolChannel, evidence.EvidenceChannel, - statesync.SnapshotChannel, statesync.ChunkChannel, + byte(statesync.SnapshotChannel), byte(statesync.ChunkChannel), }, Moniker: config.Moniker, Other: p2p.DefaultNodeInfoOther{ diff --git a/p2p/channel.go b/p2p/channel.go new file mode 100644 index 000000000..54c076a6b --- /dev/null +++ b/p2p/channel.go @@ -0,0 +1,120 @@ +package p2p + +import ( + "sync" + + "github.com/gogo/protobuf/proto" +) + +// ChannelID is an arbitrary channel ID. +type ChannelID uint16 + +// Envelope specifies the message receiver and sender. +type Envelope struct { + From PeerID // Message sender, or empty for outbound messages. + To PeerID // Message receiver, or empty for inbound messages. + Broadcast bool // Send message to all connected peers, ignoring To. + Message proto.Message // Payload. +} + +// Channel is a bidirectional channel for Protobuf message exchange with peers. +// A Channel is safe for concurrent use by multiple goroutines. +type Channel struct { + closeOnce sync.Once + + // id defines the unique channel ID. + id ChannelID + + // messageType specifies the type of messages exchanged via the channel, and + // is used e.g. for automatic unmarshaling. + messageType proto.Message + + // inCh is a channel for receiving inbound messages. Envelope.From is always + // set. + inCh chan Envelope + + // outCh is a channel for sending outbound messages. Envelope.To or Broadcast + // must be set, otherwise the message is discarded. + outCh chan Envelope + + // errCh is a channel for reporting peer errors to the router, typically used + // when peers send an invalid or malignant message. + errCh chan PeerError + + // doneCh is used to signal that a Channel is closed. A Channel is bi-directional + // and should be closed by the reactor, where as the router is responsible + // for explicitly closing the internal In channel. + doneCh chan struct{} +} + +// NewChannel returns a reference to a new p2p Channel. It is the reactor's +// responsibility to close the Channel. After a channel is closed, the router may +// safely and explicitly close the internal In channel. +func NewChannel(id ChannelID, mType proto.Message, in, out chan Envelope, errCh chan PeerError) *Channel { + return &Channel{ + id: id, + messageType: mType, + inCh: in, + outCh: out, + errCh: errCh, + doneCh: make(chan struct{}), + } +} + +// ID returns the Channel's ID. +func (c *Channel) ID() ChannelID { + return c.id +} + +// In returns a read-only inbound go channel. This go channel should be used by +// reactors to consume Envelopes sent from peers. +func (c *Channel) In() <-chan Envelope { + return c.inCh +} + +// Out returns a write-only outbound go channel. This go channel should be used +// by reactors to route Envelopes to other peers. +func (c *Channel) Out() chan<- Envelope { + return c.outCh +} + +// Error returns a write-only outbound go channel designated for peer errors only. +// This go channel should be used by reactors to send peer errors when consuming +// Envelopes sent from other peers. +func (c *Channel) Error() chan<- PeerError { + return c.errCh +} + +// Close closes the outbound channel and marks the Channel as done. 
Internally, +// the outbound outCh and peer error errCh channels are closed. It is the reactor's +// responsibility to invoke Close. Any send on the Out or Error channel will +// panic after the Channel is closed. +// +// NOTE: After a Channel is closed, the router may safely assume it can no longer +// send on the internal inCh, however it should NEVER explicitly close it as +// that could result in panics by sending on a closed channel. +func (c *Channel) Close() { + c.closeOnce.Do(func() { + close(c.doneCh) + close(c.outCh) + close(c.errCh) + }) +} + +// Done returns the Channel's internal channel that should be used by a router +// to signal when it is safe to send on the internal inCh go channel. +func (c *Channel) Done() <-chan struct{} { + return c.doneCh +} + +// Wrapper is a Protobuf message that can contain a variety of inner messages. +// If a Channel's message type implements Wrapper, the channel will +// automatically (un)wrap passed messages using the container type, such that +// the channel can transparently support multiple message types. +type Wrapper interface { + // Wrap will take a message and wrap it in this one. + Wrap(proto.Message) error + + // Unwrap will unwrap the inner message contained in this message. + Unwrap() (proto.Message, error) +} diff --git a/p2p/peer.go b/p2p/peer.go index 36db3d728..43a1e7c9c 100644 --- a/p2p/peer.go +++ b/p2p/peer.go @@ -1,8 +1,12 @@ package p2p import ( + "bytes" + "encoding/hex" "fmt" "net" + "strings" + "sync" "time" "github.com/tendermint/tendermint/libs/cmap" @@ -12,6 +16,133 @@ import ( tmconn "github.com/tendermint/tendermint/p2p/conn" ) +// PeerID is a unique peer ID, generally expressed in hex form. +type PeerID []byte + +// String implements the fmt.Stringer interface for the PeerID type. +func (pid PeerID) String() string { + return strings.ToLower(hex.EncodeToString(pid)) +} + +// Empty returns true if the PeerID is considered empty. +func (pid PeerID) Empty() bool { + return len(pid) == 0 +} + +// PeerIDFromString returns a PeerID from an encoded string or an error upon +// decode failure. +func PeerIDFromString(s string) (PeerID, error) { + bz, err := hex.DecodeString(s) + if err != nil { + return nil, fmt.Errorf("failed to decode PeerID (%s): %w", s, err) + } + + return PeerID(bz), nil +} + +// Equal reports whether two PeerID are equal. +func (pid PeerID) Equal(other PeerID) bool { + return bytes.Equal(pid, other) +} + +// PeerStatus specifies peer statuses. +type PeerStatus string + +const ( + PeerStatusNew = PeerStatus("new") // New peer which we haven't tried to contact yet. + PeerStatusUp = PeerStatus("up") // Peer which we have an active connection to. + PeerStatusDown = PeerStatus("down") // Peer which we're temporarily disconnected from. + PeerStatusRemoved = PeerStatus("removed") // Peer which has been removed. + PeerStatusBanned = PeerStatus("banned") // Peer which is banned for misbehavior. +) + +// PeerPriority specifies peer priorities. +type PeerPriority int + +const ( + PeerPriorityNormal PeerPriority = iota + 1 + PeerPriorityValidator + PeerPriorityPersistent +) + +// PeerError is a peer error reported by a reactor via the Error channel. The +// severity may cause the peer to be disconnected or banned depending on policy. +type PeerError struct { + PeerID PeerID + Err error + Severity PeerErrorSeverity +} + +// PeerErrorSeverity determines the severity of a peer error. +type PeerErrorSeverity string + +const ( + PeerErrorSeverityLow PeerErrorSeverity = "low" // Mostly ignored. 
+ PeerErrorSeverityHigh PeerErrorSeverity = "high" // May disconnect. + PeerErrorSeverityCritical PeerErrorSeverity = "critical" // Ban. +) + +// PeerUpdatesCh defines a wrapper around a PeerUpdate go channel that allows +// a reactor to listen for peer updates and safely close it when stopping. +type PeerUpdatesCh struct { + closeOnce sync.Once + + // updatesCh defines the go channel in which the router sends peer updates to + // reactors. Each reactor will have its own PeerUpdatesCh to listen for updates + // from. + updatesCh chan PeerUpdate + + // doneCh is used to signal that a PeerUpdatesCh is closed. It is the + // reactor's responsibility to invoke Close. + doneCh chan struct{} +} + +// NewPeerUpdates returns a reference to a new PeerUpdatesCh. +func NewPeerUpdates() *PeerUpdatesCh { + return &PeerUpdatesCh{ + updatesCh: make(chan PeerUpdate), + doneCh: make(chan struct{}), + } +} + +// Updates returns a read-only go channel where a consuming reactor can listen +// for peer updates sent from the router. +func (puc *PeerUpdatesCh) Updates() <-chan PeerUpdate { + return puc.updatesCh +} + +// Close closes the PeerUpdatesCh channel. It should only be closed by the respective +// reactor when stopping and ensure nothing is listening for updates. +// +// NOTE: After a PeerUpdatesCh is closed, the router may safely assume it can no +// longer send on the internal updatesCh, however it should NEVER explicitly close +// it as that could result in panics by sending on a closed channel. +func (puc *PeerUpdatesCh) Close() { + puc.closeOnce.Do(func() { + close(puc.doneCh) + }) +} + +// Done returns a read-only version of the PeerUpdatesCh's internal doneCh go +// channel that should be used by a router to signal when it is safe to explicitly +// not send any peer updates. +func (puc *PeerUpdatesCh) Done() <-chan struct{} { + return puc.doneCh +} + +// PeerUpdate is a peer status update for reactors. +type PeerUpdate struct { + PeerID PeerID + Status PeerStatus +} + +// ============================================================================ +// Types and business logic below may be deprecated. +// +// TODO: Rename once legacy p2p types are removed. 
+// ref: https://github.com/tendermint/tendermint/issues/5670 +// ============================================================================ + //go:generate mockery --case underscore --name Peer const metricsTickerDuration = 10 * time.Second diff --git a/p2p/peer_test.go b/p2p/peer_test.go index f8808f14d..77f40d1b3 100644 --- a/p2p/peer_test.go +++ b/p2p/peer_test.go @@ -19,6 +19,51 @@ import ( tmconn "github.com/tendermint/tendermint/p2p/conn" ) +func TestPeerIDFromString(t *testing.T) { + testCases := map[string]struct { + input string + expectedID PeerID + expectErr bool + }{ + "empty peer ID string": {"", PeerID{}, false}, + "invalid peer ID string": {"foo", nil, true}, + "valid peer ID string": {"ff", PeerID{0xFF}, false}, + } + + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + pID, err := PeerIDFromString(tc.input) + require.Equal(t, tc.expectErr, err != nil, err) + require.Equal(t, tc.expectedID, pID) + }) + } +} + +func TestPeerID_String(t *testing.T) { + require.Equal(t, "", PeerID{}.String()) + require.Equal(t, "ff", PeerID{0xFF}.String()) +} + +func TestPeerID_Equal(t *testing.T) { + testCases := map[string]struct { + idA PeerID + idB PeerID + equal bool + }{ + "empty IDs": {PeerID{}, PeerID{}, true}, + "not equal": {PeerID{0xFF}, PeerID{0xAA}, false}, + "equal": {PeerID{0xFF}, PeerID{0xFF}, true}, + } + + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + require.Equal(t, tc.equal, tc.idA.Equal(tc.idB)) + }) + } +} + func TestPeerBasic(t *testing.T) { assert, require := assert.New(t), require.New(t) diff --git a/p2p/shim.go b/p2p/shim.go new file mode 100644 index 000000000..ed3be90b1 --- /dev/null +++ b/p2p/shim.go @@ -0,0 +1,342 @@ +package p2p + +import ( + "errors" + "sort" + + "github.com/gogo/protobuf/proto" +) + +// ============================================================================ +// TODO: Types and business logic below are temporary and will be removed once +// the legacy p2p stack is removed in favor of the new model. +// +// ref: https://github.com/tendermint/tendermint/issues/5670 +// ============================================================================ + +var _ Reactor = (*ReactorShim)(nil) + +type ( + messageValidator interface { + Validate() error + } + + // ReactorShim defines a generic shim wrapper around a BaseReactor. It is + // responsible for wiring up legacy p2p behavior to the new p2p semantics + // (e.g. proxying Envelope messages to legacy peers). + ReactorShim struct { + BaseReactor + + Name string + PeerUpdates *PeerUpdatesCh + Channels map[ChannelID]*ChannelShim + } + + // ChannelShim defines a generic shim wrapper around a legacy p2p channel + // and the new p2p Channel. It also includes the raw bi-directional Go channels + // so we can proxy message delivery. + ChannelShim struct { + Descriptor *ChannelDescriptor + Channel *Channel + } + + // ChannelDescriptorShim defines a shim wrapper around a legacy p2p channel + // and the proto.Message the new p2p Channel is responsible for handling. + // A ChannelDescriptorShim is not contained in ReactorShim, but is rather + // used to construct a ReactorShim. 
+ ChannelDescriptorShim struct { + MsgType proto.Message + Descriptor *ChannelDescriptor + } +) + +func NewReactorShim(name string, descriptors map[ChannelID]*ChannelDescriptorShim) *ReactorShim { + channels := make(map[ChannelID]*ChannelShim) + + for _, cds := range descriptors { + chShim := NewChannelShim(cds, 0) + channels[chShim.Channel.id] = chShim + } + + rs := &ReactorShim{ + Name: name, + PeerUpdates: NewPeerUpdates(), + Channels: channels, + } + + rs.BaseReactor = *NewBaseReactor(name, rs) + + return rs +} + +func NewChannelShim(cds *ChannelDescriptorShim, buf uint) *ChannelShim { + return &ChannelShim{ + Descriptor: cds.Descriptor, + Channel: NewChannel( + ChannelID(cds.Descriptor.ID), + cds.MsgType, + make(chan Envelope, buf), + make(chan Envelope, buf), + make(chan PeerError, buf), + ), + } +} + +// proxyPeerEnvelopes iterates over each p2p Channel and starts a separate +// go-routine where we listen for outbound envelopes sent during Receive +// executions (or anything else that may send on the Channel) and proxy them to +// the corresponding Peer using the To field from the envelope. +func (rs *ReactorShim) proxyPeerEnvelopes() { + for _, cs := range rs.Channels { + go func(cs *ChannelShim) { + for e := range cs.Channel.outCh { + msg := proto.Clone(cs.Channel.messageType) + msg.Reset() + + wrapper, ok := msg.(Wrapper) + if ok { + if err := wrapper.Wrap(e.Message); err != nil { + rs.Logger.Error( + "failed to proxy envelope; failed to wrap message", + "ch_id", cs.Descriptor.ID, + "msg", e.Message, + "err", err, + ) + continue + } + } else { + msg = e.Message + } + + bz, err := proto.Marshal(msg) + if err != nil { + rs.Logger.Error( + "failed to proxy envelope; failed to encode message", + "ch_id", cs.Descriptor.ID, + "msg", e.Message, + "err", err, + ) + continue + } + + switch { + case e.Broadcast: + rs.Switch.Broadcast(cs.Descriptor.ID, bz) + + case !e.To.Empty(): + src := rs.Switch.peers.Get(ID(e.To.String())) + if src == nil { + rs.Logger.Error( + "failed to proxy envelope; failed to find peer", + "ch_id", cs.Descriptor.ID, + "msg", e.Message, + "peer", e.To.String(), + ) + continue + } + + if !src.Send(cs.Descriptor.ID, bz) { + rs.Logger.Error( + "failed to proxy message to peer", + "ch_id", cs.Descriptor.ID, + "msg", e.Message, + "peer", e.To.String(), + ) + } + + default: + rs.Logger.Error("failed to proxy envelope; missing peer ID", "ch_id", cs.Descriptor.ID, "msg", e.Message) + } + } + }(cs) + } +} + +// handlePeerErrors iterates over each p2p Channel and starts a separate go-routine +// where we listen for peer errors. For each peer error, we find the peer from +// the legacy p2p Switch and execute a StopPeerForError call with the corresponding +// peer error. +func (rs *ReactorShim) handlePeerErrors() { + for _, cs := range rs.Channels { + go func(cs *ChannelShim) { + for pErr := range cs.Channel.errCh { + if !pErr.PeerID.Empty() { + peer := rs.Switch.peers.Get(ID(pErr.PeerID.String())) + if peer == nil { + rs.Logger.Error("failed to handle peer error; failed to find peer", "peer", pErr.PeerID.String()) + continue + } + + rs.Switch.StopPeerForError(peer, pErr.Err) + } + } + }(cs) + } +} + +// OnStart executes the reactor shim's OnStart hook where we start all the +// necessary go-routines in order to proxy peer envelopes and errors per p2p +// Channel. 
+func (rs *ReactorShim) OnStart() error { + if rs.Switch == nil { + return errors.New("proxyPeerEnvelopes: reactor shim switch is nil") + } + + // start envelope proxying and peer error handling in separate go routines + rs.proxyPeerEnvelopes() + rs.handlePeerErrors() + + return nil +} + +// GetChannel returns a p2p Channel reference for a given ChannelID. If no +// Channel exists, nil is returned. +func (rs *ReactorShim) GetChannel(cID ChannelID) *Channel { + channelShim, ok := rs.Channels[cID] + if ok { + return channelShim.Channel + } + + return nil +} + +// GetChannels implements the legacy Reactor interface for getting a slice of all +// the supported ChannelDescriptors. +func (rs *ReactorShim) GetChannels() []*ChannelDescriptor { + sortedChIDs := make([]ChannelID, 0, len(rs.Channels)) + for cID := range rs.Channels { + sortedChIDs = append(sortedChIDs, cID) + } + + sort.Slice(sortedChIDs, func(i, j int) bool { return sortedChIDs[i] < sortedChIDs[j] }) + + descriptors := make([]*ChannelDescriptor, len(rs.Channels)) + for i, cID := range sortedChIDs { + descriptors[i] = rs.Channels[cID].Descriptor + } + + return descriptors +} + +// AddPeer sends a PeerUpdate with status PeerStatusUp on the PeerUpdateCh. +// The embedding reactor must be sure to listen for messages on this channel to +// handle adding a peer. +func (rs *ReactorShim) AddPeer(peer Peer) { + peerID, err := PeerIDFromString(string(peer.ID())) + if err != nil { + rs.Logger.Error("failed to add peer", "peer", peer.ID(), "err", err) + return + } + + select { + case rs.PeerUpdates.updatesCh <- PeerUpdate{PeerID: peerID, Status: PeerStatusUp}: + rs.Logger.Debug("sent peer update", "reactor", rs.Name, "peer", peerID.String(), "status", PeerStatusUp) + + case <-rs.PeerUpdates.Done(): + // NOTE: We explicitly DO NOT close the PeerUpdatesCh's updateCh go channel. + // This is because there may be numerous spawned goroutines that are + // attempting to send on the updateCh go channel and when the reactor stops + // we do not want to preemptively close the channel as that could result in + // panics sending on a closed channel. This also means that reactors MUST + // be certain there are NO listeners on the updateCh channel when closing or + // stopping. + } +} + +// RemovePeer sends a PeerUpdate with status PeerStatusDown on the PeerUpdateCh. +// The embedding reactor must be sure to listen for messages on this channel to +// handle removing a peer. +func (rs *ReactorShim) RemovePeer(peer Peer, reason interface{}) { + peerID, err := PeerIDFromString(string(peer.ID())) + if err != nil { + rs.Logger.Error("failed to remove peer", "peer", peer.ID(), "err", err) + return + } + + select { + case rs.PeerUpdates.updatesCh <- PeerUpdate{PeerID: peerID, Status: PeerStatusDown}: + rs.Logger.Debug( + "sent peer update", + "reactor", rs.Name, + "peer", peerID.String(), + "reason", reason, + "status", PeerStatusDown, + ) + + case <-rs.PeerUpdates.Done(): + // NOTE: We explicitly DO NOT close the PeerUpdatesCh's updateCh go channel. + // This is because there may be numerous spawned goroutines that are + // attempting to send on the updateCh go channel and when the reactor stops + // we do not want to preemptively close the channel as that could result in + // panics sending on a closed channel. This also means that reactors MUST + // be certain there are NO listeners on the updateCh channel when closing or + // stopping. 
+ } +} + +// Receive implements a generic wrapper around implementing the Receive method +// on the legacy Reactor p2p interface. If the reactor is running, Receive will +// find the corresponding new p2p Channel, create and decode the appropriate +// proto.Message from the msgBytes, execute any validation and finally construct +// and send a p2p Envelope on the appropriate p2p Channel. +func (rs *ReactorShim) Receive(chID byte, src Peer, msgBytes []byte) { + if !rs.IsRunning() { + return + } + + cID := ChannelID(chID) + channelShim, ok := rs.Channels[cID] + if !ok { + rs.Logger.Error("unexpected channel", "peer", src, "ch_id", chID) + return + } + + peerID, err := PeerIDFromString(string(src.ID())) + if err != nil { + rs.Logger.Error("failed to convert peer ID", "peer", src, "ch_id", chID, "err", err) + return + } + + msg := proto.Clone(channelShim.Channel.messageType) + msg.Reset() + + if err := proto.Unmarshal(msgBytes, msg); err != nil { + rs.Logger.Error("error decoding message", "peer", src, "ch_id", cID, "msg", msg, "err", err) + rs.Switch.StopPeerForError(src, err) + return + } + + validator, ok := msg.(messageValidator) + if ok { + if err := validator.Validate(); err != nil { + rs.Logger.Error("invalid message", "peer", src, "ch_id", cID, "msg", msg, "err", err) + rs.Switch.StopPeerForError(src, err) + return + } + } + + wrapper, ok := msg.(Wrapper) + if ok { + var err error + + msg, err = wrapper.Unwrap() + if err != nil { + rs.Logger.Error("failed to unwrap message", "peer", src, "ch_id", chID, "msg", msg, "err", err) + return + } + } + + select { + case channelShim.Channel.inCh <- Envelope{From: peerID, Message: msg}: + rs.Logger.Debug("proxied envelope", "reactor", rs.Name, "ch_id", cID, "peer", peerID.String()) + + case <-channelShim.Channel.Done(): + // NOTE: We explicitly DO NOT close the p2p Channel's inbound go channel. + // This is because there may be numerous spawned goroutines that are + // attempting to send on the inbound channel and when the reactor stops we + // do not want to preemptively close the channel as that could result in + // panics sending on a closed channel. This also means that reactors MUST + // be certain there are NO listeners on the inbound channel when closing or + // stopping. 
+ } +} diff --git a/p2p/shim_test.go b/p2p/shim_test.go new file mode 100644 index 000000000..218fa8d99 --- /dev/null +++ b/p2p/shim_test.go @@ -0,0 +1,207 @@ +package p2p_test + +import ( + "sync" + "testing" + + "github.com/gogo/protobuf/proto" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/tendermint/tendermint/config" + "github.com/tendermint/tendermint/p2p" + p2pmocks "github.com/tendermint/tendermint/p2p/mocks" + ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" +) + +var ( + channelID1 = byte(0x01) + channelID2 = byte(0x02) + + p2pCfg = config.DefaultP2PConfig() + + testChannelShims = map[p2p.ChannelID]*p2p.ChannelDescriptorShim{ + p2p.ChannelID(channelID1): { + MsgType: new(ssproto.Message), + Descriptor: &p2p.ChannelDescriptor{ + ID: channelID1, + Priority: 3, + SendQueueCapacity: 10, + RecvMessageCapacity: int(4e6), + }, + }, + p2p.ChannelID(channelID2): { + MsgType: new(ssproto.Message), + Descriptor: &p2p.ChannelDescriptor{ + ID: channelID2, + Priority: 1, + SendQueueCapacity: 4, + RecvMessageCapacity: int(16e6), + }, + }, + } +) + +type reactorShimTestSuite struct { + shim *p2p.ReactorShim + sw *p2p.Switch +} + +func setup(t *testing.T, peers []p2p.Peer) *reactorShimTestSuite { + t.Helper() + + rts := &reactorShimTestSuite{ + shim: p2p.NewReactorShim("TestShim", testChannelShims), + } + + rts.sw = p2p.MakeSwitch(p2pCfg, 1, "testing", "123.123.123", func(_ int, sw *p2p.Switch) *p2p.Switch { + for _, peer := range peers { + p2p.AddPeerToSwitchPeerSet(sw, peer) + } + + sw.AddReactor(rts.shim.Name, rts.shim) + return sw + }) + + // start the reactor shim + require.NoError(t, rts.shim.Start()) + + t.Cleanup(func() { + require.NoError(t, rts.shim.Stop()) + + for _, chs := range rts.shim.Channels { + chs.Channel.Close() + } + }) + + return rts +} + +func simplePeer(t *testing.T, id string) (*p2pmocks.Peer, p2p.PeerID) { + t.Helper() + + peer := &p2pmocks.Peer{} + peer.On("ID").Return(p2p.ID(id)) + + pID, err := p2p.PeerIDFromString(string(peer.ID())) + require.NoError(t, err) + + return peer, pID +} + +func TestReactorShim_GetChannel(t *testing.T) { + rts := setup(t, nil) + + p2pCh := rts.shim.GetChannel(p2p.ChannelID(channelID1)) + require.NotNil(t, p2pCh) + require.Equal(t, p2pCh.ID(), p2p.ChannelID(channelID1)) + + p2pCh = rts.shim.GetChannel(p2p.ChannelID(byte(0x03))) + require.Nil(t, p2pCh) +} + +func TestReactorShim_GetChannels(t *testing.T) { + rts := setup(t, nil) + + p2pChs := rts.shim.GetChannels() + require.Len(t, p2pChs, 2) + require.Equal(t, p2p.ChannelID(p2pChs[0].ID), p2p.ChannelID(channelID1)) + require.Equal(t, p2p.ChannelID(p2pChs[1].ID), p2p.ChannelID(channelID2)) +} + +func TestReactorShim_AddPeer(t *testing.T) { + peerA, peerIDA := simplePeer(t, "aa") + rts := setup(t, []p2p.Peer{peerA}) + + var wg sync.WaitGroup + wg.Add(1) + + var peerUpdate p2p.PeerUpdate + go func() { + peerUpdate = <-rts.shim.PeerUpdates.Updates() + wg.Done() + }() + + rts.shim.AddPeer(peerA) + wg.Wait() + + require.Equal(t, peerIDA, peerUpdate.PeerID) + require.Equal(t, p2p.PeerStatusUp, peerUpdate.Status) +} + +func TestReactorShim_RemovePeer(t *testing.T) { + peerA, peerIDA := simplePeer(t, "aa") + rts := setup(t, []p2p.Peer{peerA}) + + var wg sync.WaitGroup + wg.Add(1) + + var peerUpdate p2p.PeerUpdate + go func() { + peerUpdate = <-rts.shim.PeerUpdates.Updates() + wg.Done() + }() + + rts.shim.RemovePeer(peerA, "test reason") + wg.Wait() + + require.Equal(t, peerIDA, peerUpdate.PeerID) + require.Equal(t, p2p.PeerStatusDown, 
peerUpdate.Status) +} + +func TestReactorShim_Receive(t *testing.T) { + peerA, peerIDA := simplePeer(t, "aa") + rts := setup(t, []p2p.Peer{peerA}) + + msg := &ssproto.Message{ + Sum: &ssproto.Message_ChunkRequest{ + ChunkRequest: &ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, + }, + } + + bz, err := proto.Marshal(msg) + require.NoError(t, err) + + var wg sync.WaitGroup + + var response *ssproto.Message + peerA.On("Send", channelID1, mock.Anything).Run(func(args mock.Arguments) { + m := &ssproto.Message{} + require.NoError(t, proto.Unmarshal(args[1].([]byte), m)) + + response = m + wg.Done() + }).Return(true) + + p2pCh := rts.shim.Channels[p2p.ChannelID(channelID1)] + + wg.Add(2) + + // Simulate receiving the envelope in some real reactor and replying back with + // the same envelope and then closing the Channel. + go func() { + e := <-p2pCh.Channel.In() + require.Equal(t, peerIDA, e.From) + require.NotNil(t, e.Message) + + p2pCh.Channel.Out() <- p2p.Envelope{To: e.From, Message: e.Message} + p2pCh.Channel.Close() + wg.Done() + }() + + rts.shim.Receive(channelID1, peerA, bz) + + // wait until the mock peer called Send and we (fake) proxied the envelope + wg.Wait() + require.NotNil(t, response) + + m, err := response.Unwrap() + require.NoError(t, err) + require.Equal(t, msg.GetChunkRequest(), m) + + // Since p2pCh was closed in the simulated reactor above, calling Receive + // should not block. + rts.shim.Receive(channelID1, peerA, bz) + require.Empty(t, p2pCh.Channel.In()) + + peerA.AssertExpectations(t) +} diff --git a/proto/tendermint/statesync/message.go b/proto/tendermint/statesync/message.go new file mode 100644 index 000000000..792e7f64c --- /dev/null +++ b/proto/tendermint/statesync/message.go @@ -0,0 +1,94 @@ +package statesync + +import ( + "errors" + fmt "fmt" + + proto "github.com/gogo/protobuf/proto" +) + +// Wrap implements the p2p Wrapper interface and wraps a state sync messages. +func (m *Message) Wrap(msg proto.Message) error { + switch msg := msg.(type) { + case *ChunkRequest: + m.Sum = &Message_ChunkRequest{ChunkRequest: msg} + + case *ChunkResponse: + m.Sum = &Message_ChunkResponse{ChunkResponse: msg} + + case *SnapshotsRequest: + m.Sum = &Message_SnapshotsRequest{SnapshotsRequest: msg} + + case *SnapshotsResponse: + m.Sum = &Message_SnapshotsResponse{SnapshotsResponse: msg} + + default: + return fmt.Errorf("unknown message: %T", msg) + } + + return nil +} + +// Unwrap implements the p2p Wrapper interface and unwraps a wrapped state sync +// message. +func (m *Message) Unwrap() (proto.Message, error) { + switch msg := m.Sum.(type) { + case *Message_ChunkRequest: + return m.GetChunkRequest(), nil + + case *Message_ChunkResponse: + return m.GetChunkResponse(), nil + + case *Message_SnapshotsRequest: + return m.GetSnapshotsRequest(), nil + + case *Message_SnapshotsResponse: + return m.GetSnapshotsResponse(), nil + + default: + return nil, fmt.Errorf("unknown message: %T", msg) + } +} + +// Validate validates the message returning an error upon failure. 
+func (m *Message) Validate() error { + if m == nil { + return errors.New("message cannot be nil") + } + + switch msg := m.Sum.(type) { + case *Message_ChunkRequest: + if m.GetChunkRequest().Height == 0 { + return errors.New("height cannot be 0") + } + + case *Message_ChunkResponse: + if m.GetChunkResponse().Height == 0 { + return errors.New("height cannot be 0") + } + if m.GetChunkResponse().Missing && len(m.GetChunkResponse().Chunk) > 0 { + return errors.New("missing chunk cannot have contents") + } + if !m.GetChunkResponse().Missing && m.GetChunkResponse().Chunk == nil { + return errors.New("chunk cannot be nil") + } + + case *Message_SnapshotsRequest: + + case *Message_SnapshotsResponse: + if m.GetSnapshotsResponse().Height == 0 { + return errors.New("height cannot be 0") + } + if len(m.GetSnapshotsResponse().Hash) == 0 { + return errors.New("snapshot has no hash") + } + if m.GetSnapshotsResponse().Chunks == 0 { + return errors.New("snapshot has no chunks") + } + + default: + return fmt.Errorf("unknown message type: %T", msg) + } + + return nil +} diff --git a/proto/tendermint/statesync/message_test.go b/proto/tendermint/statesync/message_test.go new file mode 100644 index 000000000..192e17c6d --- /dev/null +++ b/proto/tendermint/statesync/message_test.go @@ -0,0 +1,86 @@ +package statesync_test + +import ( + "testing" + + proto "github.com/gogo/protobuf/proto" + "github.com/stretchr/testify/require" + + ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" + tmproto "github.com/tendermint/tendermint/proto/tendermint/types" +) + +func TestValidateMsg(t *testing.T) { + testcases := map[string]struct { + msg proto.Message + valid bool + }{ + "nil": {nil, false}, + "unrelated": {&tmproto.Block{}, false}, + + "ChunkRequest valid": {&ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, true}, + "ChunkRequest 0 height": {&ssproto.ChunkRequest{Height: 0, Format: 1, Index: 1}, false}, + "ChunkRequest 0 format": {&ssproto.ChunkRequest{Height: 1, Format: 0, Index: 1}, true}, + "ChunkRequest 0 chunk": {&ssproto.ChunkRequest{Height: 1, Format: 1, Index: 0}, true}, + + "ChunkResponse valid": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{1}}, + true}, + "ChunkResponse 0 height": { + &ssproto.ChunkResponse{Height: 0, Format: 1, Index: 1, Chunk: []byte{1}}, + false}, + "ChunkResponse 0 format": { + &ssproto.ChunkResponse{Height: 1, Format: 0, Index: 1, Chunk: []byte{1}}, + true}, + "ChunkResponse 0 chunk": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 0, Chunk: []byte{1}}, + true}, + "ChunkResponse empty body": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{}}, + true}, + "ChunkResponse nil body": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: nil}, + false}, + "ChunkResponse missing": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true}, + true}, + "ChunkResponse missing with empty": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true, Chunk: []byte{}}, + true}, + "ChunkResponse missing with body": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true, Chunk: []byte{1}}, + false}, + + "SnapshotsRequest valid": {&ssproto.SnapshotsRequest{}, true}, + + "SnapshotsResponse valid": { + &ssproto.SnapshotsResponse{Height: 1, Format: 1, Chunks: 2, Hash: []byte{1}}, + true}, + "SnapshotsResponse 0 height": { + &ssproto.SnapshotsResponse{Height: 0, Format: 1, Chunks: 2, Hash: []byte{1}}, + false}, + "SnapshotsResponse 0 format": { + 
&ssproto.SnapshotsResponse{Height: 1, Format: 0, Chunks: 2, Hash: []byte{1}}, + true}, + "SnapshotsResponse 0 chunks": { + &ssproto.SnapshotsResponse{Height: 1, Format: 1, Hash: []byte{1}}, + false}, + "SnapshotsResponse no hash": { + &ssproto.SnapshotsResponse{Height: 1, Format: 1, Chunks: 2, Hash: []byte{}}, + false}, + } + + for name, tc := range testcases { + tc := tc + t.Run(name, func(t *testing.T) { + msg := new(ssproto.Message) + _ = msg.Wrap(tc.msg) + + if tc.valid { + require.NoError(t, msg.Validate()) + } else { + require.Error(t, msg.Validate()) + } + }) + } +} diff --git a/statesync/chunks.go b/statesync/chunks.go index 028c863b9..21b48bad2 100644 --- a/statesync/chunks.go +++ b/statesync/chunks.go @@ -22,7 +22,7 @@ type chunk struct { Format uint32 Index uint32 Chunk []byte - Sender p2p.ID + Sender p2p.PeerID } // chunkQueue manages chunks for a state sync process, ordering them if requested. It acts as an @@ -33,7 +33,7 @@ type chunkQueue struct { snapshot *snapshot // if this is nil, the queue has been closed dir string // temp dir for on-disk chunk storage chunkFiles map[uint32]string // path to temporary chunk file - chunkSenders map[uint32]p2p.ID // the peer who sent the given chunk + chunkSenders map[uint32]p2p.PeerID // the peer who sent the given chunk chunkAllocated map[uint32]bool // chunks that have been allocated via Allocate() chunkReturned map[uint32]bool // chunks returned via Next() waiters map[uint32][]chan<- uint32 // signals WaitFor() waiters about chunk arrival @@ -49,11 +49,12 @@ func newChunkQueue(snapshot *snapshot, tempDir string) (*chunkQueue, error) { if snapshot.Chunks == 0 { return nil, errors.New("snapshot has no chunks") } + return &chunkQueue{ snapshot: snapshot, dir: dir, chunkFiles: make(map[uint32]string, snapshot.Chunks), - chunkSenders: make(map[uint32]p2p.ID, snapshot.Chunks), + chunkSenders: make(map[uint32]p2p.PeerID, snapshot.Chunks), chunkAllocated: make(map[uint32]bool, snapshot.Chunks), chunkReturned: make(map[uint32]bool, snapshot.Chunks), waiters: make(map[uint32][]chan<- uint32), @@ -65,8 +66,10 @@ func (q *chunkQueue) Add(chunk *chunk) (bool, error) { if chunk == nil || chunk.Chunk == nil { return false, errors.New("cannot add nil chunk") } + q.Lock() defer q.Unlock() + if q.snapshot == nil { return false, nil // queue is closed } @@ -88,6 +91,7 @@ func (q *chunkQueue) Add(chunk *chunk) (bool, error) { if err != nil { return false, fmt.Errorf("failed to save chunk %v to file %v: %w", chunk.Index, path, err) } + q.chunkFiles[chunk.Index] = path q.chunkSenders[chunk.Index] = chunk.Sender @@ -96,6 +100,7 @@ func (q *chunkQueue) Add(chunk *chunk) (bool, error) { waiter <- chunk.Index close(waiter) } + delete(q.waiters, chunk.Index) return true, nil @@ -106,18 +111,22 @@ func (q *chunkQueue) Add(chunk *chunk) (bool, error) { func (q *chunkQueue) Allocate() (uint32, error) { q.Lock() defer q.Unlock() + if q.snapshot == nil { return 0, errDone } + if uint32(len(q.chunkAllocated)) >= q.snapshot.Chunks { return 0, errDone } + for i := uint32(0); i < q.snapshot.Chunks; i++ { if !q.chunkAllocated[i] { q.chunkAllocated[i] = true return i, nil } } + return 0, errDone } @@ -125,20 +134,24 @@ func (q *chunkQueue) Allocate() (uint32, error) { func (q *chunkQueue) Close() error { q.Lock() defer q.Unlock() + if q.snapshot == nil { return nil } + for _, waiters := range q.waiters { for _, waiter := range waiters { close(waiter) } } + q.waiters = nil q.snapshot = nil - err := os.RemoveAll(q.dir) - if err != nil { + + if err := os.RemoveAll(q.dir); 
err != nil { return fmt.Errorf("failed to clean up state sync tempdir %v: %w", q.dir, err) } + return nil } @@ -156,40 +169,46 @@ func (q *chunkQueue) discard(index uint32) error { if q.snapshot == nil { return nil } + path := q.chunkFiles[index] if path == "" { return nil } - err := os.Remove(path) - if err != nil { + + if err := os.Remove(path); err != nil { return fmt.Errorf("failed to remove chunk %v: %w", index, err) } + delete(q.chunkFiles, index) delete(q.chunkReturned, index) delete(q.chunkAllocated, index) + return nil } // DiscardSender discards all *unreturned* chunks from a given sender. If the caller wants to // discard already returned chunks, this can be done via Discard(). -func (q *chunkQueue) DiscardSender(peerID p2p.ID) error { +func (q *chunkQueue) DiscardSender(peerID p2p.PeerID) error { q.Lock() defer q.Unlock() for index, sender := range q.chunkSenders { - if sender == peerID && !q.chunkReturned[index] { + if sender.Equal(peerID) && !q.chunkReturned[index] { err := q.discard(index) if err != nil { return err } + delete(q.chunkSenders, index) } } + return nil } -// GetSender returns the sender of the chunk with the given index, or empty if not found. -func (q *chunkQueue) GetSender(index uint32) p2p.ID { +// GetSender returns the sender of the chunk with the given index, or empty if +// not found. +func (q *chunkQueue) GetSender(index uint32) p2p.PeerID { q.Lock() defer q.Unlock() return q.chunkSenders[index] @@ -209,10 +228,12 @@ func (q *chunkQueue) load(index uint32) (*chunk, error) { if !ok { return nil, nil } + body, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("failed to load chunk %v: %w", index, err) } + return &chunk{ Height: q.snapshot.Height, Format: q.snapshot.Format, @@ -226,6 +247,7 @@ func (q *chunkQueue) load(index uint32) (*chunk, error) { // blocks until the chunk is available. Concurrent Next() calls may return the same chunk. 
func (q *chunkQueue) Next() (*chunk, error) { q.Lock() + var chunk *chunk index, err := q.nextUp() if err == nil { @@ -234,7 +256,9 @@ func (q *chunkQueue) Next() (*chunk, error) { q.chunkReturned[index] = true } } + q.Unlock() + if chunk != nil || err != nil { return chunk, err } @@ -250,10 +274,12 @@ func (q *chunkQueue) Next() (*chunk, error) { q.Lock() defer q.Unlock() + chunk, err = q.load(index) if err != nil { return nil, err } + q.chunkReturned[index] = true return chunk, nil } @@ -264,11 +290,13 @@ func (q *chunkQueue) nextUp() (uint32, error) { if q.snapshot == nil { return 0, errDone } + for i := uint32(0); i < q.snapshot.Chunks; i++ { if !q.chunkReturned[i] { return i, nil } } + return 0, errDone } @@ -290,9 +318,11 @@ func (q *chunkQueue) RetryAll() { func (q *chunkQueue) Size() uint32 { q.Lock() defer q.Unlock() + if q.snapshot == nil { return 0 } + return q.snapshot.Chunks } @@ -302,20 +332,26 @@ func (q *chunkQueue) Size() uint32 { func (q *chunkQueue) WaitFor(index uint32) <-chan uint32 { q.Lock() defer q.Unlock() + ch := make(chan uint32, 1) switch { case q.snapshot == nil: close(ch) + case index >= q.snapshot.Chunks: close(ch) + case q.chunkFiles[index] != "": ch <- index close(ch) + default: if q.waiters[index] == nil { q.waiters[index] = make([]chan<- uint32, 0) } + q.waiters[index] = append(q.waiters[index], ch) } + return ch } diff --git a/statesync/chunks_test.go b/statesync/chunks_test.go index 2b9a5d751..40258b5e7 100644 --- a/statesync/chunks_test.go +++ b/statesync/chunks_test.go @@ -274,7 +274,7 @@ func TestChunkQueue_DiscardSender(t *testing.T) { defer teardown() // Allocate and add all chunks to the queue - senders := []p2p.ID{"a", "b", "c"} + senders := []p2p.PeerID{p2p.PeerID("a"), p2p.PeerID("b"), p2p.PeerID("c")} for i := uint32(0); i < queue.Size(); i++ { _, err := queue.Allocate() require.NoError(t, err) @@ -295,14 +295,14 @@ func TestChunkQueue_DiscardSender(t *testing.T) { } // Discarding an unknown sender should do nothing - err := queue.DiscardSender("x") + err := queue.DiscardSender(p2p.PeerID("x")) require.NoError(t, err) _, err = queue.Allocate() assert.Equal(t, errDone, err) // Discarding sender b should discard chunk 4, but not chunk 1 which has already been // returned. 
- err = queue.DiscardSender("b") + err = queue.DiscardSender(p2p.PeerID("b")) require.NoError(t, err) index, err := queue.Allocate() require.NoError(t, err) @@ -315,21 +315,24 @@ func TestChunkQueue_GetSender(t *testing.T) { queue, teardown := setupChunkQueue(t) defer teardown() - _, err := queue.Add(&chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{1}, Sender: p2p.ID("a")}) + peerAID := p2p.PeerID{0xaa} + peerBID := p2p.PeerID{0xbb} + + _, err := queue.Add(&chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{1}, Sender: peerAID}) require.NoError(t, err) - _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{2}, Sender: p2p.ID("b")}) + _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{2}, Sender: peerBID}) require.NoError(t, err) - assert.EqualValues(t, "a", queue.GetSender(0)) - assert.EqualValues(t, "b", queue.GetSender(1)) - assert.EqualValues(t, "", queue.GetSender(2)) + assert.Equal(t, "aa", queue.GetSender(0).String()) + assert.Equal(t, "bb", queue.GetSender(1).String()) + assert.Equal(t, "", queue.GetSender(2).String()) // After the chunk has been processed, we should still know who the sender was chunk, err := queue.Next() require.NoError(t, err) require.NotNil(t, chunk) require.EqualValues(t, 0, chunk.Index) - assert.EqualValues(t, "a", queue.GetSender(0)) + assert.Equal(t, "aa", queue.GetSender(0).String()) } func TestChunkQueue_Next(t *testing.T) { @@ -351,7 +354,7 @@ func TestChunkQueue_Next(t *testing.T) { }() assert.Empty(t, chNext) - _, err := queue.Add(&chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{3, 1, 1}, Sender: p2p.ID("b")}) + _, err := queue.Add(&chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{3, 1, 1}, Sender: p2p.PeerID("b")}) require.NoError(t, err) select { case <-chNext: @@ -359,17 +362,17 @@ func TestChunkQueue_Next(t *testing.T) { default: } - _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{3, 1, 0}, Sender: p2p.ID("a")}) + _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{3, 1, 0}, Sender: p2p.PeerID("a")}) require.NoError(t, err) assert.Equal(t, - &chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{3, 1, 0}, Sender: p2p.ID("a")}, + &chunk{Height: 3, Format: 1, Index: 0, Chunk: []byte{3, 1, 0}, Sender: p2p.PeerID("a")}, <-chNext) assert.Equal(t, - &chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{3, 1, 1}, Sender: p2p.ID("b")}, + &chunk{Height: 3, Format: 1, Index: 1, Chunk: []byte{3, 1, 1}, Sender: p2p.PeerID("b")}, <-chNext) - _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 4, Chunk: []byte{3, 1, 4}, Sender: p2p.ID("e")}) + _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 4, Chunk: []byte{3, 1, 4}, Sender: p2p.PeerID("e")}) require.NoError(t, err) select { case <-chNext: @@ -377,19 +380,19 @@ func TestChunkQueue_Next(t *testing.T) { default: } - _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 2, Chunk: []byte{3, 1, 2}, Sender: p2p.ID("c")}) + _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 2, Chunk: []byte{3, 1, 2}, Sender: p2p.PeerID("c")}) require.NoError(t, err) - _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 3, Chunk: []byte{3, 1, 3}, Sender: p2p.ID("d")}) + _, err = queue.Add(&chunk{Height: 3, Format: 1, Index: 3, Chunk: []byte{3, 1, 3}, Sender: p2p.PeerID("d")}) require.NoError(t, err) assert.Equal(t, - &chunk{Height: 3, Format: 1, Index: 2, Chunk: []byte{3, 1, 2}, Sender: p2p.ID("c")}, + &chunk{Height: 3, Format: 1, Index: 2, Chunk: []byte{3, 1, 2}, Sender: p2p.PeerID("c")}, <-chNext) assert.Equal(t, - 
&chunk{Height: 3, Format: 1, Index: 3, Chunk: []byte{3, 1, 3}, Sender: p2p.ID("d")}, + &chunk{Height: 3, Format: 1, Index: 3, Chunk: []byte{3, 1, 3}, Sender: p2p.PeerID("d")}, <-chNext) assert.Equal(t, - &chunk{Height: 3, Format: 1, Index: 4, Chunk: []byte{3, 1, 4}, Sender: p2p.ID("e")}, + &chunk{Height: 3, Format: 1, Index: 4, Chunk: []byte{3, 1, 4}, Sender: p2p.PeerID("e")}, <-chNext) _, ok := <-chNext diff --git a/statesync/messages.go b/statesync/messages.go index b07227bbf..7d556aa09 100644 --- a/statesync/messages.go +++ b/statesync/messages.go @@ -1,11 +1,7 @@ package statesync import ( - "errors" - "fmt" - - "github.com/gogo/protobuf/proto" - + "github.com/tendermint/tendermint/p2p" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" ) @@ -16,82 +12,5 @@ const ( chunkMsgSize = int(16e6) ) -// mustEncodeMsg encodes a Protobuf message, panicing on error. -func mustEncodeMsg(pb proto.Message) []byte { - msg := ssproto.Message{} - switch pb := pb.(type) { - case *ssproto.ChunkRequest: - msg.Sum = &ssproto.Message_ChunkRequest{ChunkRequest: pb} - case *ssproto.ChunkResponse: - msg.Sum = &ssproto.Message_ChunkResponse{ChunkResponse: pb} - case *ssproto.SnapshotsRequest: - msg.Sum = &ssproto.Message_SnapshotsRequest{SnapshotsRequest: pb} - case *ssproto.SnapshotsResponse: - msg.Sum = &ssproto.Message_SnapshotsResponse{SnapshotsResponse: pb} - default: - panic(fmt.Errorf("unknown message type %T", pb)) - } - bz, err := msg.Marshal() - if err != nil { - panic(fmt.Errorf("unable to marshal %T: %w", pb, err)) - } - return bz -} - -// decodeMsg decodes a Protobuf message. -func decodeMsg(bz []byte) (proto.Message, error) { - pb := &ssproto.Message{} - err := proto.Unmarshal(bz, pb) - if err != nil { - return nil, err - } - switch msg := pb.Sum.(type) { - case *ssproto.Message_ChunkRequest: - return msg.ChunkRequest, nil - case *ssproto.Message_ChunkResponse: - return msg.ChunkResponse, nil - case *ssproto.Message_SnapshotsRequest: - return msg.SnapshotsRequest, nil - case *ssproto.Message_SnapshotsResponse: - return msg.SnapshotsResponse, nil - default: - return nil, fmt.Errorf("unknown message type %T", msg) - } -} - -// validateMsg validates a message. -func validateMsg(pb proto.Message) error { - if pb == nil { - return errors.New("message cannot be nil") - } - switch msg := pb.(type) { - case *ssproto.ChunkRequest: - if msg.Height == 0 { - return errors.New("height cannot be 0") - } - case *ssproto.ChunkResponse: - if msg.Height == 0 { - return errors.New("height cannot be 0") - } - if msg.Missing && len(msg.Chunk) > 0 { - return errors.New("missing chunk cannot have contents") - } - if !msg.Missing && msg.Chunk == nil { - return errors.New("chunk cannot be nil") - } - case *ssproto.SnapshotsRequest: - case *ssproto.SnapshotsResponse: - if msg.Height == 0 { - return errors.New("height cannot be 0") - } - if len(msg.Hash) == 0 { - return errors.New("snapshot has no hash") - } - if msg.Chunks == 0 { - return errors.New("snapshot has no chunks") - } - default: - return fmt.Errorf("unknown message type %T", msg) - } - return nil -} +// assert Wrapper interface implementation of the state sync proto message type. 
+var _ p2p.Wrapper = (*ssproto.Message)(nil) diff --git a/statesync/messages_test.go b/statesync/messages_test.go index 2a05f8d79..0cc0a795b 100644 --- a/statesync/messages_test.go +++ b/statesync/messages_test.go @@ -8,84 +8,10 @@ import ( "github.com/stretchr/testify/require" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" - tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) -func TestValidateMsg(t *testing.T) { - testcases := map[string]struct { - msg proto.Message - valid bool - }{ - "nil": {nil, false}, - "unrelated": {&tmproto.Block{}, false}, - - "ChunkRequest valid": {&ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, true}, - "ChunkRequest 0 height": {&ssproto.ChunkRequest{Height: 0, Format: 1, Index: 1}, false}, - "ChunkRequest 0 format": {&ssproto.ChunkRequest{Height: 1, Format: 0, Index: 1}, true}, - "ChunkRequest 0 chunk": {&ssproto.ChunkRequest{Height: 1, Format: 1, Index: 0}, true}, - - "ChunkResponse valid": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{1}}, - true}, - "ChunkResponse 0 height": { - &ssproto.ChunkResponse{Height: 0, Format: 1, Index: 1, Chunk: []byte{1}}, - false}, - "ChunkResponse 0 format": { - &ssproto.ChunkResponse{Height: 1, Format: 0, Index: 1, Chunk: []byte{1}}, - true}, - "ChunkResponse 0 chunk": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 0, Chunk: []byte{1}}, - true}, - "ChunkResponse empty body": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{}}, - true}, - "ChunkResponse nil body": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: nil}, - false}, - "ChunkResponse missing": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true}, - true}, - "ChunkResponse missing with empty": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true, Chunk: []byte{}}, - true}, - "ChunkResponse missing with body": { - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true, Chunk: []byte{1}}, - false}, - - "SnapshotsRequest valid": {&ssproto.SnapshotsRequest{}, true}, - - "SnapshotsResponse valid": { - &ssproto.SnapshotsResponse{Height: 1, Format: 1, Chunks: 2, Hash: []byte{1}}, - true}, - "SnapshotsResponse 0 height": { - &ssproto.SnapshotsResponse{Height: 0, Format: 1, Chunks: 2, Hash: []byte{1}}, - false}, - "SnapshotsResponse 0 format": { - &ssproto.SnapshotsResponse{Height: 1, Format: 0, Chunks: 2, Hash: []byte{1}}, - true}, - "SnapshotsResponse 0 chunks": { - &ssproto.SnapshotsResponse{Height: 1, Format: 1, Hash: []byte{1}}, - false}, - "SnapshotsResponse no hash": { - &ssproto.SnapshotsResponse{Height: 1, Format: 1, Chunks: 2, Hash: []byte{}}, - false}, - } - for name, tc := range testcases { - tc := tc - t.Run(name, func(t *testing.T) { - err := validateMsg(tc.msg) - if tc.valid { - require.NoError(t, err) - } else { - require.Error(t, err) - } - }) - } -} - //nolint:lll // ignore line length func TestStateSyncVectors(t *testing.T) { - testCases := []struct { testName string msg proto.Message @@ -100,8 +26,11 @@ func TestStateSyncVectors(t *testing.T) { for _, tc := range testCases { tc := tc - bz := mustEncodeMsg(tc.msg) + msg := new(ssproto.Message) + require.NoError(t, msg.Wrap(tc.msg)) + bz, err := msg.Marshal() + require.NoError(t, err) require.Equal(t, tc.expBytes, hex.EncodeToString(bz), tc.testName) } } diff --git a/statesync/reactor.go b/statesync/reactor.go index 8d6f97018..1af831e21 100644 --- a/statesync/reactor.go +++ b/statesync/reactor.go @@ -3,10 +3,13 @@ package statesync 
import ( "context" "errors" + "fmt" "sort" "time" abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/libs/log" + "github.com/tendermint/tendermint/libs/service" tmsync "github.com/tendermint/tendermint/libs/sync" "github.com/tendermint/tendermint/p2p" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" @@ -15,11 +18,45 @@ import ( "github.com/tendermint/tendermint/types" ) +var ( + _ service.Service = (*Reactor)(nil) + + // ChannelShims contains a map of ChannelDescriptorShim objects, where each + // object wraps a reference to a legacy p2p ChannelDescriptor and the corresponding + // p2p proto.Message the new p2p Channel is responsible for handling. + // + // + // TODO: Remove once p2p refactor is complete. + // ref: https://github.com/tendermint/tendermint/issues/5670 + ChannelShims = map[p2p.ChannelID]*p2p.ChannelDescriptorShim{ + SnapshotChannel: { + MsgType: new(ssproto.Message), + Descriptor: &p2p.ChannelDescriptor{ + ID: byte(SnapshotChannel), + Priority: 3, + SendQueueCapacity: 10, + RecvMessageCapacity: snapshotMsgSize, + }, + }, + ChunkChannel: { + MsgType: new(ssproto.Message), + Descriptor: &p2p.ChannelDescriptor{ + ID: byte(ChunkChannel), + Priority: 1, + SendQueueCapacity: 4, + RecvMessageCapacity: chunkMsgSize, + }, + }, + } +) + const ( // SnapshotChannel exchanges snapshot metadata - SnapshotChannel = byte(0x60) + SnapshotChannel = p2p.ChannelID(0x60) + // ChunkChannel exchanges chunk contents - ChunkChannel = byte(0x61) + ChunkChannel = p2p.ChannelID(0x61) + // recentSnapshots is the number of recent snapshots to send and receive per peer. recentSnapshots = 10 ) @@ -27,189 +64,368 @@ const ( // Reactor handles state sync, both restoring snapshots for the local node and serving snapshots // for other nodes. type Reactor struct { - p2p.BaseReactor + service.BaseService - conn proxy.AppConnSnapshot - connQuery proxy.AppConnQuery - tempDir string + conn proxy.AppConnSnapshot + connQuery proxy.AppConnQuery + tempDir string + snapshotCh *p2p.Channel + chunkCh *p2p.Channel + peerUpdates *p2p.PeerUpdatesCh + closeCh chan struct{} - // This will only be set when a state sync is in progress. It is used to feed received - // snapshots and chunks into the sync. + // This will only be set when a state sync is in progress. It is used to feed + // received snapshots and chunks into the sync. mtx tmsync.RWMutex syncer *syncer } -// NewReactor creates a new state sync reactor. -func NewReactor(conn proxy.AppConnSnapshot, connQuery proxy.AppConnQuery, tempDir string) *Reactor { +// NewReactor returns a reference to a new state sync reactor, which implements +// the service.Service interface. It accepts a logger, connections for snapshots +// and querying, references to p2p Channels and a channel to listen for peer +// updates on. Note, the reactor will close all p2p Channels when stopping. +func NewReactor( + logger log.Logger, + conn proxy.AppConnSnapshot, + connQuery proxy.AppConnQuery, + snapshotCh, chunkCh *p2p.Channel, + peerUpdates *p2p.PeerUpdatesCh, + tempDir string, +) *Reactor { r := &Reactor{ - conn: conn, - connQuery: connQuery, + conn: conn, + connQuery: connQuery, + snapshotCh: snapshotCh, + chunkCh: chunkCh, + peerUpdates: peerUpdates, + closeCh: make(chan struct{}), + tempDir: tempDir, } - r.BaseReactor = *p2p.NewBaseReactor("StateSync", r) - return r -} -// GetChannels implements p2p.Reactor. 
-func (r *Reactor) GetChannels() []*p2p.ChannelDescriptor { - return []*p2p.ChannelDescriptor{ - { - ID: SnapshotChannel, - Priority: 3, - SendQueueCapacity: 10, - RecvMessageCapacity: snapshotMsgSize, - }, - { - ID: ChunkChannel, - Priority: 1, - SendQueueCapacity: 4, - RecvMessageCapacity: chunkMsgSize, - }, - } + r.BaseService = *service.NewBaseService(logger, "StateSync", r) + return r } -// OnStart implements p2p.Reactor. +// OnStart starts separate go routines for each p2p Channel and listens for +// envelopes on each. In addition, it also listens for peer updates and handles +// messages on that p2p channel accordingly. The caller must be sure to execute +// OnStop to ensure the outbound p2p Channels are closed. No error is returned. func (r *Reactor) OnStart() error { + // Listen for envelopes on the snapshot p2p Channel in a separate go-routine + // as to not block or cause IO contention with the chunk p2p Channel. Note, + // we do not launch a go-routine to handle individual envelopes as to not + // have to deal with bounding workers or pools. + go r.processSnapshotCh() + + // Listen for envelopes on the chunk p2p Channel in a separate go-routine + // as to not block or cause IO contention with the snapshot p2p Channel. Note, + // we do not launch a go-routine to handle individual envelopes as to not + // have to deal with bounding workers or pools. + go r.processChunkCh() + + go r.processPeerUpdates() + return nil } -// AddPeer implements p2p.Reactor. -func (r *Reactor) AddPeer(peer p2p.Peer) { - r.mtx.RLock() - defer r.mtx.RUnlock() - if r.syncer != nil { - r.syncer.AddPeer(peer) - } -} +// OnStop stops the reactor by signaling to all spawned goroutines to exit and +// blocking until they all exit. +func (r *Reactor) OnStop() { + // Close closeCh to signal to all spawned goroutines to gracefully exit. All + // p2p Channels should execute Close(). + close(r.closeCh) -// RemovePeer implements p2p.Reactor. -func (r *Reactor) RemovePeer(peer p2p.Peer, reason interface{}) { - r.mtx.RLock() - defer r.mtx.RUnlock() - if r.syncer != nil { - r.syncer.RemovePeer(peer) - } + // Wait for all p2p Channels to be closed before returning. This ensures we + // can easily reason about synchronization of all p2p Channels and ensure no + // panics will occur. + <-r.snapshotCh.Done() + <-r.chunkCh.Done() + <-r.peerUpdates.Done() } -// Receive implements p2p.Reactor. -// XXX: do not call any methods that can block or incur heavy processing. -// https://github.com/tendermint/tendermint/issues/2888 -func (r *Reactor) Receive(chID byte, src p2p.Peer, msgBytes []byte) { - if !r.IsRunning() { - return - } - - msg, err := decodeMsg(msgBytes) - if err != nil { - r.Logger.Error("Error decoding message", "src", src, "chId", chID, "err", err) - r.Switch.StopPeerForError(src, err) - return - } - err = validateMsg(msg) - if err != nil { - r.Logger.Error("Invalid message", "peer", src, "msg", msg, "err", err) - r.Switch.StopPeerForError(src, err) - return - } +// handleSnapshotMessage handles enevelopes sent from peers on the +// SnapshotChannel. It returns an error only if the Envelope.Message is unknown +// for this channel. This should never be called outside of handleMessage. 
+func (r *Reactor) handleSnapshotMessage(envelope p2p.Envelope) error { + switch msg := envelope.Message.(type) { + case *ssproto.SnapshotsRequest: + snapshots, err := r.recentSnapshots(recentSnapshots) + if err != nil { + r.Logger.Error("failed to fetch snapshots", "err", err) + return nil + } - switch chID { - case SnapshotChannel: - switch msg := msg.(type) { - case *ssproto.SnapshotsRequest: - snapshots, err := r.recentSnapshots(recentSnapshots) - if err != nil { - r.Logger.Error("Failed to fetch snapshots", "err", err) - return - } - for _, snapshot := range snapshots { - r.Logger.Debug("Advertising snapshot", "height", snapshot.Height, - "format", snapshot.Format, "peer", src.ID()) - src.Send(chID, mustEncodeMsg(&ssproto.SnapshotsResponse{ + for _, snapshot := range snapshots { + r.Logger.Debug( + "advertising snapshot", + "height", snapshot.Height, + "format", snapshot.Format, + "peer", envelope.From.String(), + ) + r.snapshotCh.Out() <- p2p.Envelope{ + To: envelope.From, + Message: &ssproto.SnapshotsResponse{ Height: snapshot.Height, Format: snapshot.Format, Chunks: snapshot.Chunks, Hash: snapshot.Hash, Metadata: snapshot.Metadata, - })) + }, } + } - case *ssproto.SnapshotsResponse: - r.mtx.RLock() - defer r.mtx.RUnlock() - if r.syncer == nil { - r.Logger.Debug("Received unexpected snapshot, no state sync in progress") - return - } - r.Logger.Debug("Received snapshot", "height", msg.Height, "format", msg.Format, "peer", src.ID()) - _, err := r.syncer.AddSnapshot(src, &snapshot{ - Height: msg.Height, - Format: msg.Format, - Chunks: msg.Chunks, - Hash: msg.Hash, - Metadata: msg.Metadata, - }) - if err != nil { - r.Logger.Error("Failed to add snapshot", "height", msg.Height, "format", msg.Format, - "peer", src.ID(), "err", err) - return - } + case *ssproto.SnapshotsResponse: + r.mtx.RLock() + defer r.mtx.RUnlock() - default: - r.Logger.Error("Received unknown message %T", msg) + if r.syncer == nil { + r.Logger.Debug("received unexpected snapshot; no state sync in progress") + return nil } - case ChunkChannel: - switch msg := msg.(type) { - case *ssproto.ChunkRequest: - r.Logger.Debug("Received chunk request", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) - resp, err := r.conn.LoadSnapshotChunkSync(context.Background(), abci.RequestLoadSnapshotChunk{ - Height: msg.Height, - Format: msg.Format, - Chunk: msg.Index, - }) - if err != nil { - r.Logger.Error("Failed to load chunk", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "err", err) - return - } - r.Logger.Debug("Sending chunk", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) - src.Send(ChunkChannel, mustEncodeMsg(&ssproto.ChunkResponse{ + r.Logger.Debug( + "received snapshot", + "height", msg.Height, + "format", msg.Format, + "peer", envelope.From.String(), + ) + _, err := r.syncer.AddSnapshot(envelope.From, &snapshot{ + Height: msg.Height, + Format: msg.Format, + Chunks: msg.Chunks, + Hash: msg.Hash, + Metadata: msg.Metadata, + }) + if err != nil { + r.Logger.Error( + "failed to add snapshot", + "height", msg.Height, + "format", msg.Format, + "err", err, + "channel", r.snapshotCh.ID, + ) + return nil + } + + default: + r.Logger.Error("received unknown message", "msg", msg, "peer", envelope.From.String()) + return fmt.Errorf("received unknown message: %T", msg) + } + + return nil +} + +// handleChunkMessage handles enevelopes sent from peers on the ChunkChannel. +// It returns an error only if the Envelope.Message is unknown for this channel. 
+// This should never be called outside of handleMessage. +func (r *Reactor) handleChunkMessage(envelope p2p.Envelope) error { + switch msg := envelope.Message.(type) { + case *ssproto.ChunkRequest: + r.Logger.Debug( + "received chunk request", + "height", msg.Height, + "format", msg.Format, + "chunk", msg.Index, + "peer", envelope.From.String(), + ) + resp, err := r.conn.LoadSnapshotChunkSync(context.Background(), abci.RequestLoadSnapshotChunk{ + Height: msg.Height, + Format: msg.Format, + Chunk: msg.Index, + }) + if err != nil { + r.Logger.Error( + "failed to load chunk", + "height", msg.Height, + "format", msg.Format, + "chunk", msg.Index, + "err", err, + "peer", envelope.From.String(), + ) + return nil + } + + r.Logger.Debug( + "sending chunk", + "height", msg.Height, + "format", msg.Format, + "chunk", msg.Index, + "peer", envelope.From.String(), + ) + r.chunkCh.Out() <- p2p.Envelope{ + To: envelope.From, + Message: &ssproto.ChunkResponse{ Height: msg.Height, Format: msg.Format, Index: msg.Index, Chunk: resp.Chunk, Missing: resp.Chunk == nil, - })) - - case *ssproto.ChunkResponse: - r.mtx.RLock() - defer r.mtx.RUnlock() - if r.syncer == nil { - r.Logger.Debug("Received unexpected chunk, no state sync in progress", "peer", src.ID()) - return + }, + } + + case *ssproto.ChunkResponse: + r.mtx.RLock() + defer r.mtx.RUnlock() + + if r.syncer == nil { + r.Logger.Debug("received unexpected chunk; no state sync in progress", "peer", envelope.From.String()) + return nil + } + + r.Logger.Debug( + "received chunk; adding to sync", + "height", msg.Height, + "format", msg.Format, + "chunk", msg.Index, + "peer", envelope.From.String(), + ) + _, err := r.syncer.AddChunk(&chunk{ + Height: msg.Height, + Format: msg.Format, + Index: msg.Index, + Chunk: msg.Chunk, + Sender: envelope.From, + }) + if err != nil { + r.Logger.Error( + "failed to add chunk", + "height", msg.Height, + "format", msg.Format, + "chunk", msg.Index, + "err", err, + "peer", envelope.From.String(), + ) + return nil + } + + default: + r.Logger.Error("received unknown message", "msg", msg, "peer", envelope.From.String()) + return fmt.Errorf("received unknown message: %T", msg) + } + + return nil +} + +// handleMessage handles an Envelope sent from a peer on a specific p2p Channel. +// It will handle errors and any possible panics gracefully. A caller can handle +// any error returned by sending a PeerError on the respective channel. +func (r *Reactor) handleMessage(chID p2p.ChannelID, envelope p2p.Envelope) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic in processing message: %v", e) + r.Logger.Error("recovering from processing message panic", "err", err) + } + }() + + switch chID { + case SnapshotChannel: + err = r.handleSnapshotMessage(envelope) + + case ChunkChannel: + err = r.handleChunkMessage(envelope) + + default: + err = fmt.Errorf("unknown channel ID (%d) for envelope (%v)", chID, envelope) + } + + return err +} + +// processSnapshotCh initiates a blocking process where we listen for and handle +// envelopes on the SnapshotChannel. Any error encountered during message +// execution will result in a PeerError being sent on the SnapshotChannel. When +// the reactor is stopped, we will catch the singal and close the p2p Channel +// gracefully. 
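Before the processing loops below, note the recover-to-error idiom that handleMessage uses above: a panic while handling a peer's message is converted into an ordinary error on the named return value, which the loops then report as a PeerError instead of crashing the node. A minimal standalone sketch of that idiom follows; the handler and message type here are invented for illustration.

package main

import (
	"errors"
	"fmt"
)

// handle wraps an arbitrary message handler so that a panic inside it is
// converted into a normal error via the named return value, rather than
// taking the whole process down.
func handle(fn func(msg string) error, msg string) (err error) {
	defer func() {
		if e := recover(); e != nil {
			err = fmt.Errorf("panic in processing message: %v", e)
		}
	}()

	return fn(msg)
}

func main() {
	badHandler := func(msg string) error {
		panic("unexpected wire format")
	}
	goodHandler := func(msg string) error {
		if msg == "" {
			return errors.New("empty message")
		}
		return nil
	}

	fmt.Println(handle(badHandler, "chunk"))  // panic surfaced as an error
	fmt.Println(handle(goodHandler, "chunk")) // <nil>
}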
+func (r *Reactor) processSnapshotCh() { + defer r.snapshotCh.Close() + + for { + select { + case envelope := <-r.snapshotCh.In(): + if err := r.handleMessage(r.snapshotCh.ID(), envelope); err != nil { + r.snapshotCh.Error() <- p2p.PeerError{ + PeerID: envelope.From, + Err: err, + Severity: p2p.PeerErrorSeverityLow, + } } - r.Logger.Debug("Received chunk, adding to sync", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "peer", src.ID()) - _, err := r.syncer.AddChunk(&chunk{ - Height: msg.Height, - Format: msg.Format, - Index: msg.Index, - Chunk: msg.Chunk, - Sender: src.ID(), - }) - if err != nil { - r.Logger.Error("Failed to add chunk", "height", msg.Height, "format", msg.Format, - "chunk", msg.Index, "err", err) - return + + case <-r.closeCh: + r.Logger.Debug("stopped listening on snapshot channel; closing...") + return + } + } +} + +// processChunkCh initiates a blocking process where we listen for and handle +// envelopes on the ChunkChannel. Any error encountered during message +// execution will result in a PeerError being sent on the ChunkChannel. When +// the reactor is stopped, we will catch the singal and close the p2p Channel +// gracefully. +func (r *Reactor) processChunkCh() { + defer r.chunkCh.Close() + + for { + select { + case envelope := <-r.chunkCh.In(): + if err := r.handleMessage(r.chunkCh.ID(), envelope); err != nil { + r.chunkCh.Error() <- p2p.PeerError{ + PeerID: envelope.From, + Err: err, + Severity: p2p.PeerErrorSeverityLow, + } } - default: - r.Logger.Error("Received unknown message %T", msg) + case <-r.closeCh: + r.Logger.Debug("stopped listening on chunk channel; closing...") + return } + } +} - default: - r.Logger.Error("Received message on invalid channel %x", chID) +// processPeerUpdate processes a PeerUpdate, returning an error upon failing to +// handle the PeerUpdate or if a panic is recovered. +func (r *Reactor) processPeerUpdate(peerUpdate p2p.PeerUpdate) (err error) { + defer func() { + if e := recover(); e != nil { + err = fmt.Errorf("panic in processing peer update: %v", e) + r.Logger.Error("recovering from processing peer update panic", "err", err) + } + }() + + r.Logger.Debug("received peer update", "peer", peerUpdate.PeerID.String(), "status", peerUpdate.Status) + + r.mtx.RLock() + defer r.mtx.RUnlock() + + if r.syncer != nil { + switch peerUpdate.Status { + case p2p.PeerStatusNew, p2p.PeerStatusUp: + r.syncer.AddPeer(peerUpdate.PeerID) + + case p2p.PeerStatusDown, p2p.PeerStatusRemoved, p2p.PeerStatusBanned: + r.syncer.RemovePeer(peerUpdate.PeerID) + } + } + + return err +} + +// processPeerUpdates initiates a blocking process where we listen for and handle +// PeerUpdate messages. When the reactor is stopped, we will catch the singal and +// close the p2p PeerUpdatesCh gracefully. 
+func (r *Reactor) processPeerUpdates() { + defer r.peerUpdates.Close() + + for { + select { + case peerUpdate := <-r.peerUpdates.Updates(): + _ = r.processPeerUpdate(peerUpdate) + + case <-r.closeCh: + r.Logger.Debug("stopped listening on peer updates channel; closing...") + return + } } } @@ -219,9 +435,11 @@ func (r *Reactor) recentSnapshots(n uint32) ([]*snapshot, error) { if err != nil { return nil, err } + sort.Slice(resp.Snapshots, func(i, j int) bool { a := resp.Snapshots[i] b := resp.Snapshots[j] + switch { case a.Height > b.Height: return true @@ -231,11 +449,13 @@ func (r *Reactor) recentSnapshots(n uint32) ([]*snapshot, error) { return false } }) + snapshots := make([]*snapshot, 0, n) for i, s := range resp.Snapshots { if i >= recentSnapshots { break } + snapshots = append(snapshots, &snapshot{ Height: s.Height, Format: s.Format, @@ -244,6 +464,7 @@ func (r *Reactor) recentSnapshots(n uint32) ([]*snapshot, error) { Metadata: s.Metadata, }) } + return snapshots, nil } @@ -255,16 +476,22 @@ func (r *Reactor) Sync(stateProvider StateProvider, discoveryTime time.Duration) r.mtx.Unlock() return sm.State{}, nil, errors.New("a state sync is already in progress") } - r.syncer = newSyncer(r.Logger, r.conn, r.connQuery, stateProvider, r.tempDir) + + r.syncer = newSyncer(r.Logger, r.conn, r.connQuery, stateProvider, r.snapshotCh.Out(), r.chunkCh.Out(), r.tempDir) r.mtx.Unlock() - // Request snapshots from all currently connected peers - r.Logger.Debug("Requesting snapshots from known peers") - r.Switch.Broadcast(SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})) + // request snapshots from all currently connected peers + r.Logger.Debug("requesting snapshots from known peers") + r.snapshotCh.Out() <- p2p.Envelope{ + Broadcast: true, + Message: &ssproto.SnapshotsRequest{}, + } state, commit, err := r.syncer.SyncAny(discoveryTime) + r.mtx.Lock() r.syncer = nil r.mtx.Unlock() + return state, commit, err } diff --git a/statesync/reactor_test.go b/statesync/reactor_test.go index 72062ca9d..9d527f08a 100644 --- a/statesync/reactor_test.go +++ b/statesync/reactor_test.go @@ -5,18 +5,132 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" abci "github.com/tendermint/tendermint/abci/types" + "github.com/tendermint/tendermint/libs/log" "github.com/tendermint/tendermint/p2p" - p2pmocks "github.com/tendermint/tendermint/p2p/mocks" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" proxymocks "github.com/tendermint/tendermint/proxy/mocks" + "github.com/tendermint/tendermint/statesync/mocks" ) -func TestReactor_Receive_ChunkRequest(t *testing.T) { +type reactorTestSuite struct { + reactor *Reactor + syncer *syncer + + conn *proxymocks.AppConnSnapshot + connQuery *proxymocks.AppConnQuery + stateProvider *mocks.StateProvider + + snapshotChannel *p2p.Channel + snapshotInCh chan p2p.Envelope + snapshotOutCh chan p2p.Envelope + snapshotPeerErrCh chan p2p.PeerError + + chunkChannel *p2p.Channel + chunkInCh chan p2p.Envelope + chunkOutCh chan p2p.Envelope + chunkPeerErrCh chan p2p.PeerError + + peerUpdates *p2p.PeerUpdatesCh +} + +func setup( + t *testing.T, + conn *proxymocks.AppConnSnapshot, + connQuery *proxymocks.AppConnQuery, + stateProvider *mocks.StateProvider, + chBuf uint, +) *reactorTestSuite { + t.Helper() + + if conn == nil { + conn = &proxymocks.AppConnSnapshot{} + } + if connQuery == nil { + connQuery = &proxymocks.AppConnQuery{} + } + if stateProvider == nil { + stateProvider = 
&mocks.StateProvider{} + } + + rts := &reactorTestSuite{ + snapshotInCh: make(chan p2p.Envelope, chBuf), + snapshotOutCh: make(chan p2p.Envelope, chBuf), + snapshotPeerErrCh: make(chan p2p.PeerError, chBuf), + chunkInCh: make(chan p2p.Envelope, chBuf), + chunkOutCh: make(chan p2p.Envelope, chBuf), + chunkPeerErrCh: make(chan p2p.PeerError, chBuf), + peerUpdates: p2p.NewPeerUpdates(), + conn: conn, + connQuery: connQuery, + stateProvider: stateProvider, + } + + rts.snapshotChannel = p2p.NewChannel( + SnapshotChannel, + new(ssproto.Message), + rts.snapshotInCh, + rts.snapshotOutCh, + rts.snapshotPeerErrCh, + ) + + rts.chunkChannel = p2p.NewChannel( + ChunkChannel, + new(ssproto.Message), + rts.chunkInCh, + rts.chunkOutCh, + rts.chunkPeerErrCh, + ) + + rts.reactor = NewReactor( + log.NewNopLogger(), + conn, + connQuery, + rts.snapshotChannel, + rts.chunkChannel, + rts.peerUpdates, + "", + ) + + rts.syncer = newSyncer( + log.NewNopLogger(), + conn, + connQuery, + stateProvider, + rts.snapshotOutCh, + rts.chunkOutCh, + "", + ) + + require.NoError(t, rts.reactor.Start()) + require.True(t, rts.reactor.IsRunning()) + + t.Cleanup(func() { + require.NoError(t, rts.reactor.Stop()) + require.False(t, rts.reactor.IsRunning()) + }) + + return rts +} + +func TestReactor_ChunkRequest_InvalidRequest(t *testing.T) { + rts := setup(t, nil, nil, nil, 2) + + rts.chunkInCh <- p2p.Envelope{ + From: p2p.PeerID{0xAA}, + Message: &ssproto.SnapshotsRequest{}, + } + + response := <-rts.chunkPeerErrCh + require.Error(t, response.Err) + require.Empty(t, rts.chunkOutCh) + require.Contains(t, response.Err.Error(), "received unknown message") + require.Equal(t, p2p.PeerID{0xAA}, response.PeerID) +} + +func TestReactor_ChunkRequest(t *testing.T) { testcases := map[string]struct { request *ssproto.ChunkRequest chunk []byte @@ -25,22 +139,30 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { "chunk is returned": { &ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, []byte{1, 2, 3}, - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{1, 2, 3}}}, - "empty chunk is returned, as nil": { + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{1, 2, 3}}, + }, + "empty chunk is returned, as empty": { &ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, []byte{}, - &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: nil}}, + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Chunk: []byte{}}, + }, "nil (missing) chunk is returned as missing": { &ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, nil, &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true}, }, + "invalid request": { + &ssproto.ChunkRequest{Height: 1, Format: 1, Index: 1}, + nil, + &ssproto.ChunkResponse{Height: 1, Format: 1, Index: 1, Missing: true}, + }, } for name, tc := range testcases { tc := tc + t.Run(name, func(t *testing.T) { - // Mock ABCI connection to return local snapshots + // mock ABCI connection to return local snapshots conn := &proxymocks.AppConnSnapshot{} conn.On("LoadSnapshotChunkSync", context.Background(), abci.RequestLoadSnapshotChunk{ Height: tc.request.Height, @@ -48,39 +170,38 @@ func TestReactor_Receive_ChunkRequest(t *testing.T) { Chunk: tc.request.Index, }).Return(&abci.ResponseLoadSnapshotChunk{Chunk: tc.chunk}, nil) - // Mock peer to store response, if found - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID("id")) - var response *ssproto.ChunkResponse - if tc.expectResponse != nil { - peer.On("Send", ChunkChannel, mock.Anything).Run(func(args 
mock.Arguments) { - msg, err := decodeMsg(args[1].([]byte)) - require.NoError(t, err) - response = msg.(*ssproto.ChunkResponse) - }).Return(true) - } + rts := setup(t, conn, nil, nil, 2) - // Start a reactor and send a ssproto.ChunkRequest, then wait for and check response - r := NewReactor(conn, nil, "") - err := r.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := r.Stop(); err != nil { - t.Error(err) - } - }) + rts.chunkInCh <- p2p.Envelope{ + From: p2p.PeerID{0xAA}, + Message: tc.request, + } - r.Receive(ChunkChannel, peer, mustEncodeMsg(tc.request)) - time.Sleep(100 * time.Millisecond) - assert.Equal(t, tc.expectResponse, response) + response := <-rts.chunkOutCh + require.Equal(t, tc.expectResponse, response.Message) + require.Empty(t, rts.chunkOutCh) conn.AssertExpectations(t) - peer.AssertExpectations(t) }) } } -func TestReactor_Receive_SnapshotsRequest(t *testing.T) { +func TestReactor_SnapshotsRequest_InvalidRequest(t *testing.T) { + rts := setup(t, nil, nil, nil, 2) + + rts.snapshotInCh <- p2p.Envelope{ + From: p2p.PeerID{0xAA}, + Message: &ssproto.ChunkRequest{}, + } + + response := <-rts.snapshotPeerErrCh + require.Error(t, response.Err) + require.Empty(t, rts.snapshotOutCh) + require.Contains(t, response.Err.Error(), "received unknown message") + require.Equal(t, p2p.PeerID{0xAA}, response.PeerID) +} + +func TestReactor_SnapshotsRequest(t *testing.T) { testcases := map[string]struct { snapshots []*abci.Snapshot expectResponses []*ssproto.SnapshotsResponse @@ -118,41 +239,48 @@ func TestReactor_Receive_SnapshotsRequest(t *testing.T) { for name, tc := range testcases { tc := tc + t.Run(name, func(t *testing.T) { - // Mock ABCI connection to return local snapshots + // mock ABCI connection to return local snapshots conn := &proxymocks.AppConnSnapshot{} conn.On("ListSnapshotsSync", context.Background(), abci.RequestListSnapshots{}).Return(&abci.ResponseListSnapshots{ Snapshots: tc.snapshots, }, nil) - // Mock peer to catch responses and store them in a slice - responses := []*ssproto.SnapshotsResponse{} - peer := &p2pmocks.Peer{} - if len(tc.expectResponses) > 0 { - peer.On("ID").Return(p2p.ID("id")) - peer.On("Send", SnapshotChannel, mock.Anything).Run(func(args mock.Arguments) { - msg, err := decodeMsg(args[1].([]byte)) - require.NoError(t, err) - responses = append(responses, msg.(*ssproto.SnapshotsResponse)) - }).Return(true) + rts := setup(t, conn, nil, nil, 100) + + rts.snapshotInCh <- p2p.Envelope{ + From: p2p.PeerID{0xAA}, + Message: &ssproto.SnapshotsRequest{}, } - // Start a reactor and send a SnapshotsRequestMessage, then wait for and check responses - r := NewReactor(conn, nil, "") - err := r.Start() - require.NoError(t, err) - t.Cleanup(func() { - if err := r.Stop(); err != nil { - t.Error(err) - } - }) + if len(tc.expectResponses) > 0 { + retryUntil(t, func() bool { return len(rts.snapshotOutCh) == len(tc.expectResponses) }, time.Second) + } - r.Receive(SnapshotChannel, peer, mustEncodeMsg(&ssproto.SnapshotsRequest{})) - time.Sleep(100 * time.Millisecond) - assert.Equal(t, tc.expectResponses, responses) + responses := make([]*ssproto.SnapshotsResponse, len(tc.expectResponses)) + for i := 0; i < len(tc.expectResponses); i++ { + e := <-rts.snapshotOutCh + responses[i] = e.Message.(*ssproto.SnapshotsResponse) + } - conn.AssertExpectations(t) - peer.AssertExpectations(t) + require.Equal(t, tc.expectResponses, responses) + require.Empty(t, rts.snapshotOutCh) }) } } + +// retryUntil will continue to evaluate fn and will return successfully when true +// or 
fail when the timeout is reached. +func retryUntil(t *testing.T, fn func() bool, timeout time.Duration) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + if fn() { + return + } + + require.NoError(t, ctx.Err()) + } +} diff --git a/statesync/snapshots.go b/statesync/snapshots.go index 3eca5c0be..ab8fa8e2e 100644 --- a/statesync/snapshots.go +++ b/statesync/snapshots.go @@ -1,6 +1,7 @@ package statesync import ( + "bytes" "context" "crypto/sha256" "fmt" @@ -46,16 +47,16 @@ type snapshotPool struct { tmsync.Mutex snapshots map[snapshotKey]*snapshot - snapshotPeers map[snapshotKey]map[p2p.ID]p2p.Peer + snapshotPeers map[snapshotKey]map[string]p2p.PeerID // indexes for fast searches formatIndex map[uint32]map[snapshotKey]bool heightIndex map[uint64]map[snapshotKey]bool - peerIndex map[p2p.ID]map[snapshotKey]bool + peerIndex map[string]map[snapshotKey]bool // blacklists for rejected items formatBlacklist map[uint32]bool - peerBlacklist map[p2p.ID]bool + peerBlacklist map[string]bool snapshotBlacklist map[snapshotKey]bool } @@ -64,20 +65,21 @@ func newSnapshotPool(stateProvider StateProvider) *snapshotPool { return &snapshotPool{ stateProvider: stateProvider, snapshots: make(map[snapshotKey]*snapshot), - snapshotPeers: make(map[snapshotKey]map[p2p.ID]p2p.Peer), + snapshotPeers: make(map[snapshotKey]map[string]p2p.PeerID), formatIndex: make(map[uint32]map[snapshotKey]bool), heightIndex: make(map[uint64]map[snapshotKey]bool), - peerIndex: make(map[p2p.ID]map[snapshotKey]bool), + peerIndex: make(map[string]map[snapshotKey]bool), formatBlacklist: make(map[uint32]bool), - peerBlacklist: make(map[p2p.ID]bool), + peerBlacklist: make(map[string]bool), snapshotBlacklist: make(map[snapshotKey]bool), } } -// Add adds a snapshot to the pool, unless the peer has already sent recentSnapshots snapshots. It -// returns true if this was a new, non-blacklisted snapshot. The snapshot height is verified using -// the light client, and the expected app hash is set for the snapshot. -func (p *snapshotPool) Add(peer p2p.Peer, snapshot *snapshot) (bool, error) { +// Add adds a snapshot to the pool, unless the peer has already sent recentSnapshots +// snapshots. It returns true if this was a new, non-blacklisted snapshot. The +// snapshot height is verified using the light client, and the expected app hash +// is set for the snapshot. 
+func (p *snapshotPool) Add(peer p2p.PeerID, snapshot *snapshot) (bool, error) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -94,23 +96,23 @@ func (p *snapshotPool) Add(peer p2p.Peer, snapshot *snapshot) (bool, error) { switch { case p.formatBlacklist[snapshot.Format]: return false, nil - case p.peerBlacklist[peer.ID()]: + case p.peerBlacklist[peer.String()]: return false, nil case p.snapshotBlacklist[key]: return false, nil - case len(p.peerIndex[peer.ID()]) >= recentSnapshots: + case len(p.peerIndex[peer.String()]) >= recentSnapshots: return false, nil } if p.snapshotPeers[key] == nil { - p.snapshotPeers[key] = make(map[p2p.ID]p2p.Peer) + p.snapshotPeers[key] = make(map[string]p2p.PeerID) } - p.snapshotPeers[key][peer.ID()] = peer + p.snapshotPeers[key][peer.String()] = peer - if p.peerIndex[peer.ID()] == nil { - p.peerIndex[peer.ID()] = make(map[snapshotKey]bool) + if p.peerIndex[peer.String()] == nil { + p.peerIndex[peer.String()] = make(map[snapshotKey]bool) } - p.peerIndex[peer.ID()][key] = true + p.peerIndex[peer.String()][key] = true if p.snapshots[key] != nil { return false, nil @@ -140,7 +142,7 @@ func (p *snapshotPool) Best() *snapshot { } // GetPeer returns a random peer for a snapshot, if any. -func (p *snapshotPool) GetPeer(snapshot *snapshot) p2p.Peer { +func (p *snapshotPool) GetPeer(snapshot *snapshot) p2p.PeerID { peers := p.GetPeers(snapshot) if len(peers) == 0 { return nil @@ -149,19 +151,22 @@ func (p *snapshotPool) GetPeer(snapshot *snapshot) p2p.Peer { } // GetPeers returns the peers for a snapshot. -func (p *snapshotPool) GetPeers(snapshot *snapshot) []p2p.Peer { +func (p *snapshotPool) GetPeers(snapshot *snapshot) []p2p.PeerID { key := snapshot.Key() + p.Lock() defer p.Unlock() - peers := make([]p2p.Peer, 0, len(p.snapshotPeers[key])) + peers := make([]p2p.PeerID, 0, len(p.snapshotPeers[key])) for _, peer := range p.snapshotPeers[key] { peers = append(peers, peer) } + // sort results, for testability (otherwise order is random, so tests randomly fail) sort.Slice(peers, func(a int, b int) bool { - return peers[a].ID() < peers[b].ID() + return bytes.Compare(peers[a], peers[b]) < 0 }) + return peers } @@ -222,33 +227,35 @@ func (p *snapshotPool) RejectFormat(format uint32) { } // RejectPeer rejects a peer. It will never be used again. -func (p *snapshotPool) RejectPeer(peerID p2p.ID) { - if peerID == "" { +func (p *snapshotPool) RejectPeer(peerID p2p.PeerID) { + if len(peerID) == 0 { return } + p.Lock() defer p.Unlock() p.removePeer(peerID) - p.peerBlacklist[peerID] = true + p.peerBlacklist[peerID.String()] = true } // RemovePeer removes a peer from the pool, and any snapshots that no longer have peers. -func (p *snapshotPool) RemovePeer(peerID p2p.ID) { +func (p *snapshotPool) RemovePeer(peerID p2p.PeerID) { p.Lock() defer p.Unlock() p.removePeer(peerID) } // removePeer removes a peer. The caller must hold the mutex lock. -func (p *snapshotPool) removePeer(peerID p2p.ID) { - for key := range p.peerIndex[peerID] { - delete(p.snapshotPeers[key], peerID) +func (p *snapshotPool) removePeer(peerID p2p.PeerID) { + for key := range p.peerIndex[peerID.String()] { + delete(p.snapshotPeers[key], peerID.String()) if len(p.snapshotPeers[key]) == 0 { p.removeSnapshot(key) } } - delete(p.peerIndex, peerID) + + delete(p.peerIndex, peerID.String()) } // removeSnapshot removes a snapshot. The caller must hold the mutex lock. 
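A note on the peer bookkeeping in the snapshots.go hunks above: judging by the tests, p2p.PeerID is a byte-slice style identifier (e.g. p2p.PeerID{0xAA}), and Go slices cannot be used as map keys, which is why snapshotPeers and peerIndex are now keyed by peerID.String() while GetPeers sorts with bytes.Compare for deterministic output. Below is a self-contained sketch of that pattern using a stand-in []byte type rather than the real p2p package.

package main

import (
	"bytes"
	"fmt"
	"sort"
)

// peerID stands in for p2p.PeerID: a raw byte identifier that cannot be used
// directly as a map key, so maps are keyed by its string form instead.
type peerID []byte

func (id peerID) String() string { return fmt.Sprintf("%X", id) }

func main() {
	a := peerID{0xAA}
	b := peerID{0xBB}

	// Index peers by their string form, keeping the raw ID as the value so it
	// can still be sent back out on the wire.
	index := map[string]peerID{
		a.String(): a,
		b.String(): b,
	}

	// Collect and sort deterministically, mirroring GetPeers, so that callers
	// relying on ordering (such as tests) do not flake.
	peers := make([]peerID, 0, len(index))
	for _, p := range index {
		peers = append(peers, p)
	}
	sort.Slice(peers, func(i, j int) bool {
		return bytes.Compare(peers[i], peers[j]) < 0
	})

	fmt.Println(peers) // [AA BB]
}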
diff --git a/statesync/snapshots_test.go b/statesync/snapshots_test.go index 588c0ac31..866267fb7 100644 --- a/statesync/snapshots_test.go +++ b/statesync/snapshots_test.go @@ -3,12 +3,10 @@ package statesync import ( "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/p2p" - p2pmocks "github.com/tendermint/tendermint/p2p/mocks" "github.com/tendermint/tendermint/statesync/mocks" ) @@ -35,7 +33,7 @@ func TestSnapshot_Key(t *testing.T) { before := s.Key() tc.modify(&s) after := s.Key() - assert.NotEqual(t, before, after) + require.NotEqual(t, before, after) }) } } @@ -44,36 +42,34 @@ func TestSnapshotPool_Add(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, uint64(1)).Return([]byte("app_hash"), nil) - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID("id")) + peerID := p2p.PeerID{0xAA} // Adding to the pool should work pool := newSnapshotPool(stateProvider) - added, err := pool.Add(peer, &snapshot{ + added, err := pool.Add(peerID, &snapshot{ Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}, }) require.NoError(t, err) - assert.True(t, added) + require.True(t, added) // Adding again from a different peer should return false - otherPeer := &p2pmocks.Peer{} - otherPeer.On("ID").Return(p2p.ID("other")) - added, err = pool.Add(peer, &snapshot{ + otherPeerID := p2p.PeerID{0xBB} + added, err = pool.Add(otherPeerID, &snapshot{ Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}, }) require.NoError(t, err) - assert.False(t, added) + require.False(t, added) // The pool should have populated the snapshot with the trusted app hash snapshot := pool.Best() require.NotNil(t, snapshot) - assert.Equal(t, []byte("app_hash"), snapshot.trustedAppHash) + require.Equal(t, []byte("app_hash"), snapshot.trustedAppHash) stateProvider.AssertExpectations(t) } @@ -84,16 +80,17 @@ func TestSnapshotPool_GetPeer(t *testing.T) { pool := newSnapshotPool(stateProvider) s := &snapshot{Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}} - peerA := &p2pmocks.Peer{} - peerA.On("ID").Return(p2p.ID("a")) - peerB := &p2pmocks.Peer{} - peerB.On("ID").Return(p2p.ID("b")) - _, err := pool.Add(peerA, s) + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} + + _, err := pool.Add(peerAID, s) require.NoError(t, err) - _, err = pool.Add(peerB, s) + + _, err = pool.Add(peerBID, s) require.NoError(t, err) - _, err = pool.Add(peerA, &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{1}}) + + _, err = pool.Add(peerAID, &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{1}}) require.NoError(t, err) // GetPeer currently picks a random peer, so lets run it until we've seen both. 
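The loop in the next hunk keeps calling GetPeer until both peers have been seen; since selection is random, it is unbounded in the worst case. The deadline-bounded retryUntil helper this patch adds to reactor_test.go shows one way such loops can be capped. A rough, standard-library-only sketch of the same idea follows; it is hypothetical and not part of the patch.

package main

import (
	"context"
	"fmt"
	"math/rand"
	"time"
)

// retryUntil keeps evaluating fn until it reports true or the timeout expires,
// returning false in the latter case instead of spinning forever.
func retryUntil(fn func() bool, timeout time.Duration) bool {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	for {
		if fn() {
			return true
		}
		if ctx.Err() != nil {
			return false
		}
	}
}

func main() {
	seenA, seenB := false, false

	ok := retryUntil(func() bool {
		// Stand-in for pool.GetPeer's random selection between two peers.
		if rand.Intn(2) == 0 {
			seenA = true
		} else {
			seenB = true
		}
		return seenA && seenB
	}, time.Second)

	fmt.Println("saw both peers:", ok)
}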
@@ -101,17 +98,17 @@ func TestSnapshotPool_GetPeer(t *testing.T) { seenB := false for !seenA || !seenB { peer := pool.GetPeer(s) - switch peer.ID() { - case p2p.ID("a"): + if peer.Equal(peerAID) { seenA = true - case p2p.ID("b"): + } + if peer.Equal(peerBID) { seenB = true } } // GetPeer should return nil for an unknown snapshot peer := pool.GetPeer(&snapshot{Height: 9, Format: 9}) - assert.Nil(t, peer) + require.Nil(t, peer) } func TestSnapshotPool_GetPeers(t *testing.T) { @@ -120,22 +117,23 @@ func TestSnapshotPool_GetPeers(t *testing.T) { pool := newSnapshotPool(stateProvider) s := &snapshot{Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}} - peerA := &p2pmocks.Peer{} - peerA.On("ID").Return(p2p.ID("a")) - peerB := &p2pmocks.Peer{} - peerB.On("ID").Return(p2p.ID("b")) - _, err := pool.Add(peerA, s) + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} + + _, err := pool.Add(peerAID, s) require.NoError(t, err) - _, err = pool.Add(peerB, s) + + _, err = pool.Add(peerBID, s) require.NoError(t, err) - _, err = pool.Add(peerA, &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{2}}) + + _, err = pool.Add(peerAID, &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{2}}) require.NoError(t, err) peers := pool.GetPeers(s) - assert.Len(t, peers, 2) - assert.EqualValues(t, "a", peers[0].ID()) - assert.EqualValues(t, "b", peers[1].ID()) + require.Len(t, peers, 2) + require.Equal(t, peerAID, peers[0]) + require.EqualValues(t, peerBID, peers[1]) } func TestSnapshotPool_Ranked_Best(t *testing.T) { @@ -150,28 +148,30 @@ func TestSnapshotPool_Ranked_Best(t *testing.T) { snapshot *snapshot peers []string }{ - {&snapshot{Height: 2, Format: 2, Chunks: 4, Hash: []byte{1, 3}}, []string{"a", "b", "c"}}, - {&snapshot{Height: 2, Format: 2, Chunks: 5, Hash: []byte{1, 2}}, []string{"a"}}, - {&snapshot{Height: 2, Format: 1, Chunks: 3, Hash: []byte{1, 2}}, []string{"a", "b"}}, - {&snapshot{Height: 1, Format: 2, Chunks: 5, Hash: []byte{1, 2}}, []string{"a", "b"}}, - {&snapshot{Height: 1, Format: 1, Chunks: 4, Hash: []byte{1, 2}}, []string{"a", "b", "c"}}, + {&snapshot{Height: 2, Format: 2, Chunks: 4, Hash: []byte{1, 3}}, []string{"AA", "BB", "CC"}}, + {&snapshot{Height: 2, Format: 2, Chunks: 5, Hash: []byte{1, 2}}, []string{"AA"}}, + {&snapshot{Height: 2, Format: 1, Chunks: 3, Hash: []byte{1, 2}}, []string{"AA", "BB"}}, + {&snapshot{Height: 1, Format: 2, Chunks: 5, Hash: []byte{1, 2}}, []string{"AA", "BB"}}, + {&snapshot{Height: 1, Format: 1, Chunks: 4, Hash: []byte{1, 2}}, []string{"AA", "BB", "CC"}}, } // Add snapshots in reverse order, to make sure the pool enforces some order. 
for i := len(expectSnapshots) - 1; i >= 0; i-- { - for _, peerID := range expectSnapshots[i].peers { - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID(peerID)) - _, err := pool.Add(peer, expectSnapshots[i].snapshot) + for _, peerIDStr := range expectSnapshots[i].peers { + peerID, err := p2p.PeerIDFromString(peerIDStr) + require.NoError(t, err) + + _, err = pool.Add(peerID, expectSnapshots[i].snapshot) require.NoError(t, err) } } // Ranked should return the snapshots in the same order ranked := pool.Ranked() - assert.Len(t, ranked, len(expectSnapshots)) + require.Len(t, ranked, len(expectSnapshots)) + for i := range ranked { - assert.Equal(t, expectSnapshots[i].snapshot, ranked[i]) + require.Equal(t, expectSnapshots[i].snapshot, ranked[i]) } // Check that best snapshots are returned in expected order @@ -180,15 +180,16 @@ func TestSnapshotPool_Ranked_Best(t *testing.T) { require.Equal(t, snapshot, pool.Best()) pool.Reject(snapshot) } - assert.Nil(t, pool.Best()) + + require.Nil(t, pool.Best()) } func TestSnapshotPool_Reject(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) pool := newSnapshotPool(stateProvider) - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID("id")) + + peerID := p2p.PeerID{0xAA} snapshots := []*snapshot{ {Height: 2, Format: 2, Chunks: 1, Hash: []byte{1, 2}}, @@ -197,28 +198,28 @@ func TestSnapshotPool_Reject(t *testing.T) { {Height: 1, Format: 1, Chunks: 1, Hash: []byte{1, 2}}, } for _, s := range snapshots { - _, err := pool.Add(peer, s) + _, err := pool.Add(peerID, s) require.NoError(t, err) } pool.Reject(snapshots[0]) - assert.Equal(t, snapshots[1:], pool.Ranked()) + require.Equal(t, snapshots[1:], pool.Ranked()) - added, err := pool.Add(peer, snapshots[0]) + added, err := pool.Add(peerID, snapshots[0]) require.NoError(t, err) - assert.False(t, added) + require.False(t, added) - added, err = pool.Add(peer, &snapshot{Height: 3, Format: 3, Chunks: 1, Hash: []byte{1}}) + added, err = pool.Add(peerID, &snapshot{Height: 3, Format: 3, Chunks: 1, Hash: []byte{1}}) require.NoError(t, err) - assert.True(t, added) + require.True(t, added) } func TestSnapshotPool_RejectFormat(t *testing.T) { stateProvider := &mocks.StateProvider{} stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) pool := newSnapshotPool(stateProvider) - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID("id")) + + peerID := p2p.PeerID{0xAA} snapshots := []*snapshot{ {Height: 2, Format: 2, Chunks: 1, Hash: []byte{1, 2}}, @@ -227,21 +228,21 @@ func TestSnapshotPool_RejectFormat(t *testing.T) { {Height: 1, Format: 1, Chunks: 1, Hash: []byte{1, 2}}, } for _, s := range snapshots { - _, err := pool.Add(peer, s) + _, err := pool.Add(peerID, s) require.NoError(t, err) } pool.RejectFormat(1) - assert.Equal(t, []*snapshot{snapshots[0], snapshots[2]}, pool.Ranked()) + require.Equal(t, []*snapshot{snapshots[0], snapshots[2]}, pool.Ranked()) - added, err := pool.Add(peer, &snapshot{Height: 3, Format: 1, Chunks: 1, Hash: []byte{1}}) + added, err := pool.Add(peerID, &snapshot{Height: 3, Format: 1, Chunks: 1, Hash: []byte{1}}) require.NoError(t, err) - assert.False(t, added) - assert.Equal(t, []*snapshot{snapshots[0], snapshots[2]}, pool.Ranked()) + require.False(t, added) + require.Equal(t, []*snapshot{snapshots[0], snapshots[2]}, pool.Ranked()) - added, err = pool.Add(peer, &snapshot{Height: 3, Format: 3, Chunks: 1, Hash: []byte{1}}) + added, err = pool.Add(peerID, 
&snapshot{Height: 3, Format: 3, Chunks: 1, Hash: []byte{1}}) require.NoError(t, err) - assert.True(t, added) + require.True(t, added) } func TestSnapshotPool_RejectPeer(t *testing.T) { @@ -249,41 +250,41 @@ func TestSnapshotPool_RejectPeer(t *testing.T) { stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) pool := newSnapshotPool(stateProvider) - peerA := &p2pmocks.Peer{} - peerA.On("ID").Return(p2p.ID("a")) - peerB := &p2pmocks.Peer{} - peerB.On("ID").Return(p2p.ID("b")) + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} s1 := &snapshot{Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}} s2 := &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{2}} s3 := &snapshot{Height: 3, Format: 1, Chunks: 1, Hash: []byte{2}} - _, err := pool.Add(peerA, s1) + _, err := pool.Add(peerAID, s1) require.NoError(t, err) - _, err = pool.Add(peerA, s2) + + _, err = pool.Add(peerAID, s2) require.NoError(t, err) - _, err = pool.Add(peerB, s2) + _, err = pool.Add(peerBID, s2) require.NoError(t, err) - _, err = pool.Add(peerB, s3) + + _, err = pool.Add(peerBID, s3) require.NoError(t, err) - pool.RejectPeer(peerA.ID()) + pool.RejectPeer(peerAID) - assert.Empty(t, pool.GetPeers(s1)) + require.Empty(t, pool.GetPeers(s1)) peers2 := pool.GetPeers(s2) - assert.Len(t, peers2, 1) - assert.EqualValues(t, "b", peers2[0].ID()) + require.Len(t, peers2, 1) + require.Equal(t, peerBID, peers2[0]) peers3 := pool.GetPeers(s2) - assert.Len(t, peers3, 1) - assert.EqualValues(t, "b", peers3[0].ID()) + require.Len(t, peers3, 1) + require.Equal(t, peerBID, peers3[0]) // it should no longer be possible to add the peer back - _, err = pool.Add(peerA, s1) + _, err = pool.Add(peerAID, s1) require.NoError(t, err) - assert.Empty(t, pool.GetPeers(s1)) + require.Empty(t, pool.GetPeers(s1)) } func TestSnapshotPool_RemovePeer(t *testing.T) { @@ -291,35 +292,36 @@ func TestSnapshotPool_RemovePeer(t *testing.T) { stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) pool := newSnapshotPool(stateProvider) - peerA := &p2pmocks.Peer{} - peerA.On("ID").Return(p2p.ID("a")) - peerB := &p2pmocks.Peer{} - peerB.On("ID").Return(p2p.ID("b")) + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} s1 := &snapshot{Height: 1, Format: 1, Chunks: 1, Hash: []byte{1}} s2 := &snapshot{Height: 2, Format: 1, Chunks: 1, Hash: []byte{2}} - _, err := pool.Add(peerA, s1) + _, err := pool.Add(peerAID, s1) require.NoError(t, err) - _, err = pool.Add(peerA, s2) + + _, err = pool.Add(peerAID, s2) require.NoError(t, err) - _, err = pool.Add(peerB, s1) + + _, err = pool.Add(peerBID, s1) require.NoError(t, err) - pool.RemovePeer(peerA.ID()) + pool.RemovePeer(peerAID) peers1 := pool.GetPeers(s1) - assert.Len(t, peers1, 1) - assert.EqualValues(t, "b", peers1[0].ID()) + require.Len(t, peers1, 1) + require.Equal(t, peerBID, peers1[0]) peers2 := pool.GetPeers(s2) - assert.Empty(t, peers2) + require.Empty(t, peers2) // it should still be possible to add the peer back - _, err = pool.Add(peerA, s1) + _, err = pool.Add(peerAID, s1) require.NoError(t, err) + peers1 = pool.GetPeers(s1) - assert.Len(t, peers1, 2) - assert.EqualValues(t, "a", peers1[0].ID()) - assert.EqualValues(t, "b", peers1[1].ID()) + require.Len(t, peers1, 2) + require.Equal(t, peerAID, peers1[0]) + require.Equal(t, peerBID, peers1[1]) } diff --git a/statesync/syncer.go b/statesync/syncer.go index b4c3aa51f..7e09a4b7b 100644 --- a/statesync/syncer.go +++ b/statesync/syncer.go @@ -54,6 +54,8 @@ type syncer struct { conn 
proxy.AppConnSnapshot connQuery proxy.AppConnQuery snapshots *snapshotPool + snapshotCh chan<- p2p.Envelope + chunkCh chan<- p2p.Envelope tempDir string mtx tmsync.RWMutex @@ -61,14 +63,22 @@ type syncer struct { } // newSyncer creates a new syncer. -func newSyncer(logger log.Logger, conn proxy.AppConnSnapshot, connQuery proxy.AppConnQuery, - stateProvider StateProvider, tempDir string) *syncer { +func newSyncer( + logger log.Logger, + conn proxy.AppConnSnapshot, + connQuery proxy.AppConnQuery, + stateProvider StateProvider, + snapshotCh, chunkCh chan<- p2p.Envelope, + tempDir string, +) *syncer { return &syncer{ logger: logger, stateProvider: stateProvider, conn: conn, connQuery: connQuery, snapshots: newSnapshotPool(stateProvider), + snapshotCh: snapshotCh, + chunkCh: chunkCh, tempDir: tempDir, } } @@ -97,7 +107,7 @@ func (s *syncer) AddChunk(chunk *chunk) (bool, error) { // AddSnapshot adds a snapshot to the snapshot pool. It returns true if a new, previously unseen // snapshot was accepted and added. -func (s *syncer) AddSnapshot(peer p2p.Peer, snapshot *snapshot) (bool, error) { +func (s *syncer) AddSnapshot(peer p2p.PeerID, snapshot *snapshot) (bool, error) { added, err := s.snapshots.Add(peer, snapshot) if err != nil { return false, err @@ -109,17 +119,20 @@ func (s *syncer) AddSnapshot(peer p2p.Peer, snapshot *snapshot) (bool, error) { return added, nil } -// AddPeer adds a peer to the pool. For now we just keep it simple and send a single request -// to discover snapshots, later we may want to do retries and stuff. -func (s *syncer) AddPeer(peer p2p.Peer) { - s.logger.Debug("Requesting snapshots from peer", "peer", peer.ID()) - peer.Send(SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})) +// AddPeer adds a peer to the pool. For now we just keep it simple and send a +// single request to discover snapshots, later we may want to do retries and stuff. +func (s *syncer) AddPeer(peer p2p.PeerID) { + s.logger.Debug("Requesting snapshots from peer", "peer", peer.String()) + s.snapshotCh <- p2p.Envelope{ + To: peer, + Message: &ssproto.SnapshotsRequest{}, + } } // RemovePeer removes a peer from the pool. 
-func (s *syncer) RemovePeer(peer p2p.Peer) { - s.logger.Debug("Removing peer from sync", "peer", peer.ID()) - s.snapshots.RemovePeer(peer.ID()) +func (s *syncer) RemovePeer(peer p2p.PeerID) { + s.logger.Debug("Removing peer from sync", "peer", peer.String()) + s.snapshots.RemovePeer(peer) } // SyncAny tries to sync any of the snapshots in the snapshot pool, waiting to discover further @@ -192,8 +205,8 @@ func (s *syncer) SyncAny(discoveryTime time.Duration) (sm.State, *types.Commit, s.logger.Info("Snapshot senders rejected", "height", snapshot.Height, "format", snapshot.Format, "hash", fmt.Sprintf("%X", snapshot.Hash)) for _, peer := range s.snapshots.GetPeers(snapshot) { - s.snapshots.RejectPeer(peer.ID()) - s.logger.Info("Snapshot sender rejected", "peer", peer.ID()) + s.snapshots.RejectPeer(peer) + s.logger.Info("Snapshot sender rejected", "peer", peer.String()) } default: @@ -322,7 +335,7 @@ func (s *syncer) applyChunks(chunks *chunkQueue) error { resp, err := s.conn.ApplySnapshotChunkSync(context.Background(), abci.RequestApplySnapshotChunk{ Index: chunk.Index, Chunk: chunk.Chunk, - Sender: string(chunk.Sender), + Sender: chunk.Sender.String(), }) if err != nil { return fmt.Errorf("failed to apply chunk %v: %w", chunk.Index, err) @@ -341,9 +354,14 @@ func (s *syncer) applyChunks(chunks *chunkQueue) error { // Reject any senders as requested by the app for _, sender := range resp.RejectSenders { if sender != "" { - s.snapshots.RejectPeer(p2p.ID(sender)) - err := chunks.DiscardSender(p2p.ID(sender)) + peerID, err := p2p.PeerIDFromString(sender) if err != nil { + return err + } + + s.snapshots.RejectPeer(peerID) + + if err := chunks.DiscardSender(peerID); err != nil { return fmt.Errorf("failed to reject sender: %w", err) } } @@ -410,13 +428,23 @@ func (s *syncer) requestChunk(snapshot *snapshot, chunk uint32) { "format", snapshot.Format, "hash", snapshot.Hash) return } - s.logger.Debug("Requesting snapshot chunk", "height", snapshot.Height, - "format", snapshot.Format, "chunk", chunk, "peer", peer.ID()) - peer.Send(ChunkChannel, mustEncodeMsg(&ssproto.ChunkRequest{ - Height: snapshot.Height, - Format: snapshot.Format, - Index: chunk, - })) + + s.logger.Debug( + "Requesting snapshot chunk", + "height", snapshot.Height, + "format", snapshot.Format, + "chunk", chunk, + "peer", peer.String(), + ) + + s.chunkCh <- p2p.Envelope{ + To: peer, + Message: &ssproto.ChunkRequest{ + Height: snapshot.Height, + Format: snapshot.Format, + Index: chunk, + }, + } } // verifyApp verifies the sync, checking the app hash and last block height. 
It returns the @@ -426,18 +454,23 @@ func (s *syncer) verifyApp(snapshot *snapshot) (uint64, error) { if err != nil { return 0, fmt.Errorf("failed to query ABCI app for appHash: %w", err) } + if !bytes.Equal(snapshot.trustedAppHash, resp.LastBlockAppHash) { s.logger.Error("appHash verification failed", "expected", fmt.Sprintf("%X", snapshot.trustedAppHash), "actual", fmt.Sprintf("%X", resp.LastBlockAppHash)) return 0, errVerifyFailed } + if uint64(resp.LastBlockHeight) != snapshot.Height { - s.logger.Error("ABCI app reported unexpected last block height", - "expected", snapshot.Height, "actual", resp.LastBlockHeight) + s.logger.Error( + "ABCI app reported unexpected last block height", + "expected", snapshot.Height, + "actual", resp.LastBlockHeight, + ) return 0, errVerifyFailed } - s.logger.Info("Verified ABCI app", "height", snapshot.Height, - "appHash", fmt.Sprintf("%X", snapshot.trustedAppHash)) + + s.logger.Info("Verified ABCI app", "height", snapshot.Height, "appHash", fmt.Sprintf("%X", snapshot.trustedAppHash)) return resp.AppVersion, nil } diff --git a/statesync/syncer_test.go b/statesync/syncer_test.go index 96459c26e..229e91c98 100644 --- a/statesync/syncer_test.go +++ b/statesync/syncer_test.go @@ -6,15 +6,12 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" abci "github.com/tendermint/tendermint/abci/types" - "github.com/tendermint/tendermint/libs/log" tmsync "github.com/tendermint/tendermint/libs/sync" "github.com/tendermint/tendermint/p2p" - p2pmocks "github.com/tendermint/tendermint/p2p/mocks" tmstate "github.com/tendermint/tendermint/proto/tendermint/state" ssproto "github.com/tendermint/tendermint/proto/tendermint/statesync" tmversion "github.com/tendermint/tendermint/proto/tendermint/version" @@ -28,23 +25,6 @@ import ( var ctx = context.Background() -// Sets up a basic syncer that can be used to test OfferSnapshot requests -func setupOfferSyncer(t *testing.T) (*syncer, *proxymocks.AppConnSnapshot) { - connQuery := &proxymocks.AppConnQuery{} - connSnapshot := &proxymocks.AppConnSnapshot{} - stateProvider := &mocks.StateProvider{} - stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "") - return syncer, connSnapshot -} - -// Sets up a simple peer mock with an ID -func simplePeer(id string) *p2pmocks.Peer { - peer := &p2pmocks.Peer{} - peer.On("ID").Return(p2p.ID(id)) - return peer -} - func TestSyncer_SyncAny(t *testing.T) { state := sm.State{ ChainID: "chain", @@ -53,7 +33,6 @@ func TestSyncer_SyncAny(t *testing.T) { Block: version.BlockProtocol, App: 0, }, - Software: version.TMCoreSemVer, }, @@ -87,38 +66,39 @@ func TestSyncer_SyncAny(t *testing.T) { connSnapshot := &proxymocks.AppConnSnapshot{} connQuery := &proxymocks.AppConnQuery{} - syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "") + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} + + rts := setup(t, connSnapshot, connQuery, stateProvider, 3) // Adding a chunk should error when no sync is in progress - _, err := syncer.AddChunk(&chunk{Height: 1, Format: 1, Index: 0, Chunk: []byte{1}}) + _, err := rts.syncer.AddChunk(&chunk{Height: 1, Format: 1, Index: 0, Chunk: []byte{1}}) require.Error(t, err) // Adding a couple of peers should trigger snapshot discovery messages - peerA := &p2pmocks.Peer{} - peerA.On("ID").Return(p2p.ID("a")) - peerA.On("Send", SnapshotChannel, 
mustEncodeMsg(&ssproto.SnapshotsRequest{})).Return(true) - syncer.AddPeer(peerA) - peerA.AssertExpectations(t) - - peerB := &p2pmocks.Peer{} - peerB.On("ID").Return(p2p.ID("b")) - peerB.On("Send", SnapshotChannel, mustEncodeMsg(&ssproto.SnapshotsRequest{})).Return(true) - syncer.AddPeer(peerB) - peerB.AssertExpectations(t) + rts.syncer.AddPeer(peerAID) + e := <-rts.snapshotOutCh + require.Equal(t, &ssproto.SnapshotsRequest{}, e.Message) + require.Equal(t, peerAID, e.To) + + rts.syncer.AddPeer(peerBID) + e = <-rts.snapshotOutCh + require.Equal(t, &ssproto.SnapshotsRequest{}, e.Message) + require.Equal(t, peerBID, e.To) // Both peers report back with snapshots. One of them also returns a snapshot we don't want, in // format 2, which will be rejected by the ABCI application. - new, err := syncer.AddSnapshot(peerA, s) + new, err := rts.syncer.AddSnapshot(peerAID, s) require.NoError(t, err) - assert.True(t, new) + require.True(t, new) - new, err = syncer.AddSnapshot(peerB, s) + new, err = rts.syncer.AddSnapshot(peerBID, s) require.NoError(t, err) - assert.False(t, new) + require.False(t, new) - new, err = syncer.AddSnapshot(peerB, &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1}}) + new, err = rts.syncer.AddSnapshot(peerBID, &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1}}) require.NoError(t, err) - assert.True(t, new) + require.True(t, new) // We start a sync, with peers sending back chunks when requested. We first reject the snapshot // with height 2 format 2, and accept the snapshot at height 1. @@ -144,24 +124,25 @@ func TestSyncer_SyncAny(t *testing.T) { chunkRequests := make(map[uint32]int) chunkRequestsMtx := tmsync.Mutex{} - onChunkRequest := func(args mock.Arguments) { - pb, err := decodeMsg(args[1].([]byte)) - assert.NoError(t, err) - msg := pb.(*ssproto.ChunkRequest) - assert.EqualValues(t, 1, msg.Height) - assert.EqualValues(t, 1, msg.Format) - assert.LessOrEqual(t, msg.Index, uint32(len(chunks))) - - added, err := syncer.AddChunk(chunks[msg.Index]) - assert.NoError(t, err) - assert.True(t, added) - - chunkRequestsMtx.Lock() - chunkRequests[msg.Index]++ - chunkRequestsMtx.Unlock() - } - peerA.On("Send", ChunkChannel, mock.Anything).Maybe().Run(onChunkRequest).Return(true) - peerB.On("Send", ChunkChannel, mock.Anything).Maybe().Run(onChunkRequest).Return(true) + + go func() { + for e := range rts.chunkOutCh { + msg, ok := e.Message.(*ssproto.ChunkRequest) + require.True(t, ok) + + require.EqualValues(t, 1, msg.Height) + require.EqualValues(t, 1, msg.Format) + require.LessOrEqual(t, msg.Index, uint32(len(chunks))) + + added, err := rts.syncer.AddChunk(chunks[msg.Index]) + require.NoError(t, err) + require.True(t, added) + + chunkRequestsMtx.Lock() + chunkRequests[msg.Index]++ + chunkRequestsMtx.Unlock() + } + }() // The first time we're applying chunk 2 we tell it to retry the snapshot and discard chunk 1, // which should cause it to keep the existing chunk 0 and 2, and restart restoration from @@ -189,113 +170,140 @@ func TestSyncer_SyncAny(t *testing.T) { LastBlockAppHash: []byte("app_hash"), }, nil) - newState, lastCommit, err := syncer.SyncAny(0) + newState, lastCommit, err := rts.syncer.SyncAny(0) require.NoError(t, err) time.Sleep(50 * time.Millisecond) // wait for peers to receive requests chunkRequestsMtx.Lock() - assert.Equal(t, map[uint32]int{0: 1, 1: 2, 2: 1}, chunkRequests) + require.Equal(t, map[uint32]int{0: 1, 1: 2, 2: 1}, chunkRequests) chunkRequestsMtx.Unlock() // The syncer should have updated the state app version from the ABCI info response. 
expectState := state expectState.Version.Consensus.App = 9 - assert.Equal(t, expectState, newState) - assert.Equal(t, commit, lastCommit) + require.Equal(t, expectState, newState) + require.Equal(t, commit, lastCommit) connSnapshot.AssertExpectations(t) connQuery.AssertExpectations(t) - peerA.AssertExpectations(t) - peerB.AssertExpectations(t) } func TestSyncer_SyncAny_noSnapshots(t *testing.T) { - syncer, _ := setupOfferSyncer(t) - _, _, err := syncer.SyncAny(0) - assert.Equal(t, errNoSnapshots, err) + stateProvider := &mocks.StateProvider{} + stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) + + rts := setup(t, nil, nil, stateProvider, 2) + + _, _, err := rts.syncer.SyncAny(0) + require.Equal(t, errNoSnapshots, err) } func TestSyncer_SyncAny_abort(t *testing.T) { - syncer, connSnapshot := setupOfferSyncer(t) + stateProvider := &mocks.StateProvider{} + stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) + + rts := setup(t, nil, nil, stateProvider, 2) s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}} - _, err := syncer.AddSnapshot(simplePeer("id"), s) + peerID := p2p.PeerID{0xAA} + + _, err := rts.syncer.AddSnapshot(peerID, s) require.NoError(t, err) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) - _, _, err = syncer.SyncAny(0) - assert.Equal(t, errAbort, err) - connSnapshot.AssertExpectations(t) + _, _, err = rts.syncer.SyncAny(0) + require.Equal(t, errAbort, err) + rts.conn.AssertExpectations(t) } func TestSyncer_SyncAny_reject(t *testing.T) { - syncer, connSnapshot := setupOfferSyncer(t) + stateProvider := &mocks.StateProvider{} + stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) + + rts := setup(t, nil, nil, stateProvider, 2) // s22 is tried first, then s12, then s11, then errNoSnapshots s22 := &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} s12 := &snapshot{Height: 1, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} s11 := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}} - _, err := syncer.AddSnapshot(simplePeer("id"), s22) + + peerID := p2p.PeerID{0xAA} + + _, err := rts.syncer.AddSnapshot(peerID, s22) require.NoError(t, err) - _, err = syncer.AddSnapshot(simplePeer("id"), s12) + + _, err = rts.syncer.AddSnapshot(peerID, s12) require.NoError(t, err) - _, err = syncer.AddSnapshot(simplePeer("id"), s11) + + _, err = rts.syncer.AddSnapshot(peerID, s11) require.NoError(t, err) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s22), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s12), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s11), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: 
abci.ResponseOfferSnapshot_REJECT}, nil) - _, _, err = syncer.SyncAny(0) - assert.Equal(t, errNoSnapshots, err) - connSnapshot.AssertExpectations(t) + _, _, err = rts.syncer.SyncAny(0) + require.Equal(t, errNoSnapshots, err) + rts.conn.AssertExpectations(t) } func TestSyncer_SyncAny_reject_format(t *testing.T) { - syncer, connSnapshot := setupOfferSyncer(t) + stateProvider := &mocks.StateProvider{} + stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) + + rts := setup(t, nil, nil, stateProvider, 2) // s22 is tried first, which reject s22 and s12, then s11 will abort. s22 := &snapshot{Height: 2, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} s12 := &snapshot{Height: 1, Format: 2, Chunks: 3, Hash: []byte{1, 2, 3}} s11 := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}} - _, err := syncer.AddSnapshot(simplePeer("id"), s22) + + peerID := p2p.PeerID{0xAA} + + _, err := rts.syncer.AddSnapshot(peerID, s22) require.NoError(t, err) - _, err = syncer.AddSnapshot(simplePeer("id"), s12) + + _, err = rts.syncer.AddSnapshot(peerID, s12) require.NoError(t, err) - _, err = syncer.AddSnapshot(simplePeer("id"), s11) + + _, err = rts.syncer.AddSnapshot(peerID, s11) require.NoError(t, err) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s22), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT_FORMAT}, nil) - connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ + rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{ Snapshot: toABCI(s11), AppHash: []byte("app_hash"), }).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_ABORT}, nil) - _, _, err = syncer.SyncAny(0) - assert.Equal(t, errAbort, err) - connSnapshot.AssertExpectations(t) + _, _, err = rts.syncer.SyncAny(0) + require.Equal(t, errAbort, err) + rts.conn.AssertExpectations(t) } func TestSyncer_SyncAny_reject_sender(t *testing.T) { - syncer, connSnapshot := setupOfferSyncer(t) + stateProvider := &mocks.StateProvider{} + stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil) - peerA := simplePeer("a") - peerB := simplePeer("b") - peerC := simplePeer("c") + rts := setup(t, nil, nil, stateProvider, 2) + + peerAID := p2p.PeerID{0xAA} + peerBID := p2p.PeerID{0xBB} + peerCID := p2p.PeerID{0xCC} // sbc will be offered first, which will be rejected with reject_sender, causing all snapshots // submitted by both b and c (i.e. sb, sc, sbc) to be rejected. 
Finally, sa will reject and
@@ -304,44 +312,56 @@ func TestSyncer_SyncAny_reject_sender(t *testing.T) {
 	sb := &snapshot{Height: 2, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}}
 	sc := &snapshot{Height: 3, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}}
 	sbc := &snapshot{Height: 4, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}}
-	_, err := syncer.AddSnapshot(peerA, sa)
+
+	_, err := rts.syncer.AddSnapshot(peerAID, sa)
 	require.NoError(t, err)
-	_, err = syncer.AddSnapshot(peerB, sb)
+
+	_, err = rts.syncer.AddSnapshot(peerBID, sb)
 	require.NoError(t, err)
-	_, err = syncer.AddSnapshot(peerC, sc)
+
+	_, err = rts.syncer.AddSnapshot(peerCID, sc)
 	require.NoError(t, err)
-	_, err = syncer.AddSnapshot(peerB, sbc)
+
+	_, err = rts.syncer.AddSnapshot(peerBID, sbc)
 	require.NoError(t, err)
-	_, err = syncer.AddSnapshot(peerC, sbc)
+
+	_, err = rts.syncer.AddSnapshot(peerCID, sbc)
 	require.NoError(t, err)
 
-	connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
+	rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
 		Snapshot: toABCI(sbc), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT_SENDER}, nil)
-	connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
+	rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
 		Snapshot: toABCI(sa), AppHash: []byte("app_hash"),
 	}).Once().Return(&abci.ResponseOfferSnapshot{Result: abci.ResponseOfferSnapshot_REJECT}, nil)
 
-	_, _, err = syncer.SyncAny(0)
-	assert.Equal(t, errNoSnapshots, err)
-	connSnapshot.AssertExpectations(t)
+	_, _, err = rts.syncer.SyncAny(0)
+	require.Equal(t, errNoSnapshots, err)
+	rts.conn.AssertExpectations(t)
 }
 
 func TestSyncer_SyncAny_abciError(t *testing.T) {
-	syncer, connSnapshot := setupOfferSyncer(t)
+	stateProvider := &mocks.StateProvider{}
+	stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
+
+	rts := setup(t, nil, nil, stateProvider, 2)
 
 	errBoom := errors.New("boom")
 	s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}}
-	_, err := syncer.AddSnapshot(simplePeer("id"), s)
+
+	peerID := p2p.PeerID{0xAA}
+
+	_, err := rts.syncer.AddSnapshot(peerID, s)
 	require.NoError(t, err)
-	connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
+
+	rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
 		Snapshot: toABCI(s), AppHash: []byte("app_hash"),
 	}).Once().Return(nil, errBoom)
 
-	_, _, err = syncer.SyncAny(0)
-	assert.True(t, errors.Is(err, errBoom))
-	connSnapshot.AssertExpectations(t)
+	_, _, err = rts.syncer.SyncAny(0)
+	require.True(t, errors.Is(err, errBoom))
+	rts.conn.AssertExpectations(t)
 }
 
 func TestSyncer_offerSnapshot(t *testing.T) {
@@ -365,13 +385,18 @@ func TestSyncer_offerSnapshot(t *testing.T) {
 	for name, tc := range testcases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			syncer, connSnapshot := setupOfferSyncer(t)
+			stateProvider := &mocks.StateProvider{}
+			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
+
+			rts := setup(t, nil, nil, stateProvider, 2)
+
 			s := &snapshot{Height: 1, Format: 1, Chunks: 3, Hash: []byte{1, 2, 3}, trustedAppHash: []byte("app_hash")}
-			connSnapshot.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
+			rts.conn.On("OfferSnapshotSync", ctx, abci.RequestOfferSnapshot{
 				Snapshot: toABCI(s), AppHash: []byte("app_hash"),
 			}).Return(&abci.ResponseOfferSnapshot{Result: tc.result}, tc.err)
-			err := syncer.offerSnapshot(s)
+
+			err := rts.syncer.offerSnapshot(s)
 			if tc.expectErr == unknownErr {
 				require.Error(t, err)
 			} else {
@@ -379,7 +404,7 @@ func TestSyncer_offerSnapshot(t *testing.T) {
 				if unwrapped != nil {
 					err = unwrapped
 				}
-				assert.Equal(t, tc.expectErr, err)
+				require.Equal(t, tc.expectErr, err)
 			}
 		})
 	}
@@ -406,11 +431,10 @@ func TestSyncer_applyChunks_Results(t *testing.T) {
 	for name, tc := range testcases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			connQuery := &proxymocks.AppConnQuery{}
-			connSnapshot := &proxymocks.AppConnSnapshot{}
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
-			syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "")
+
+			rts := setup(t, nil, nil, stateProvider, 2)
 
 			body := []byte{1, 2, 3}
 			chunks, err := newChunkQueue(&snapshot{Height: 1, Format: 1, Chunks: 1}, "")
@@ -418,17 +442,17 @@ func TestSyncer_applyChunks_Results(t *testing.T) {
 			_, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 0, Chunk: body})
 			require.NoError(t, err)
 
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
 				Index: 0, Chunk: body,
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: tc.result}, tc.err)
 			if tc.result == abci.ResponseApplySnapshotChunk_RETRY {
-				connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+				rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
 					Index: 0, Chunk: body,
 				}).Once().Return(&abci.ResponseApplySnapshotChunk{
 					Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
 			}
 
-			err = syncer.applyChunks(chunks)
+			err = rts.syncer.applyChunks(chunks)
 			if tc.expectErr == unknownErr {
 				require.Error(t, err)
 			} else {
@@ -436,9 +460,10 @@ func TestSyncer_applyChunks_Results(t *testing.T) {
 				if unwrapped != nil {
 					err = unwrapped
 				}
-				assert.Equal(t, tc.expectErr, err)
+				require.Equal(t, tc.expectErr, err)
 			}
-			connSnapshot.AssertExpectations(t)
+
+			rts.conn.AssertExpectations(t)
 		})
 	}
 }
@@ -457,11 +482,10 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) {
 	for name, tc := range testcases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			connQuery := &proxymocks.AppConnQuery{}
-			connSnapshot := &proxymocks.AppConnSnapshot{}
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
-			syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "")
+
+			rts := setup(t, nil, nil, stateProvider, 2)
 
 			chunks, err := newChunkQueue(&snapshot{Height: 1, Format: 1, Chunks: 3}, "")
 			require.NoError(t, err)
@@ -476,13 +500,13 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) {
 			require.NoError(t, err)
 
 			// The first two chunks are accepted, before the last one asks for 1 to be refetched
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
 				Index: 0, Chunk: []byte{0},
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
 				Index: 1, Chunk: []byte{1},
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
 				Index: 2, Chunk: []byte{2},
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{
 				Result: tc.result,
@@ -493,15 +517,15 @@ func TestSyncer_applyChunks_RefetchChunks(t *testing.T) {
 			// check the queue contents, and finally close the queue to end the goroutine.
 			// We don't really care about the result of applyChunks, since it has separate test.
 			go func() {
-				syncer.applyChunks(chunks) //nolint:errcheck // purposefully ignore error
+				rts.syncer.applyChunks(chunks) //nolint:errcheck // purposefully ignore error
 			}()
 			time.Sleep(50 * time.Millisecond)
 
-			assert.True(t, chunks.Has(0))
-			assert.False(t, chunks.Has(1))
-			assert.True(t, chunks.Has(2))
-			err = chunks.Close()
-			require.NoError(t, err)
+			require.True(t, chunks.Has(0))
+			require.False(t, chunks.Has(1))
+			require.True(t, chunks.Has(2))
+
+			require.NoError(t, chunks.Close())
 		})
 	}
 }
@@ -520,63 +544,71 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) {
 	for name, tc := range testcases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			connQuery := &proxymocks.AppConnQuery{}
-			connSnapshot := &proxymocks.AppConnSnapshot{}
 			stateProvider := &mocks.StateProvider{}
 			stateProvider.On("AppHash", mock.Anything, mock.Anything).Return([]byte("app_hash"), nil)
-			syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "")
+
+			rts := setup(t, nil, nil, stateProvider, 2)
 
 			// Set up three peers across two snapshots, and ask for one of them to be banned.
 			// It should be banned from all snapshots.
-			peerA := simplePeer("a")
-			peerB := simplePeer("b")
-			peerC := simplePeer("c")
+			peerAID := p2p.PeerID{0xAA}
+			peerBID := p2p.PeerID{0xBB}
+			peerCID := p2p.PeerID{0xCC}
 
 			s1 := &snapshot{Height: 1, Format: 1, Chunks: 3}
 			s2 := &snapshot{Height: 2, Format: 1, Chunks: 3}
-			_, err := syncer.AddSnapshot(peerA, s1)
+
+			_, err := rts.syncer.AddSnapshot(peerAID, s1)
 			require.NoError(t, err)
-			_, err = syncer.AddSnapshot(peerA, s2)
+
+			_, err = rts.syncer.AddSnapshot(peerAID, s2)
 			require.NoError(t, err)
-			_, err = syncer.AddSnapshot(peerB, s1)
+
+			_, err = rts.syncer.AddSnapshot(peerBID, s1)
 			require.NoError(t, err)
-			_, err = syncer.AddSnapshot(peerB, s2)
+
+			_, err = rts.syncer.AddSnapshot(peerBID, s2)
 			require.NoError(t, err)
-			_, err = syncer.AddSnapshot(peerC, s1)
+
+			_, err = rts.syncer.AddSnapshot(peerCID, s1)
 			require.NoError(t, err)
-			_, err = syncer.AddSnapshot(peerC, s2)
+
+			_, err = rts.syncer.AddSnapshot(peerCID, s2)
 			require.NoError(t, err)
 
 			chunks, err := newChunkQueue(s1, "")
 			require.NoError(t, err)
-			added, err := chunks.Add(&chunk{Height: 1, Format: 1, Index: 0, Chunk: []byte{0}, Sender: peerA.ID()})
+
+			added, err := chunks.Add(&chunk{Height: 1, Format: 1, Index: 0, Chunk: []byte{0}, Sender: peerAID})
 			require.True(t, added)
 			require.NoError(t, err)
-			added, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 1, Chunk: []byte{1}, Sender: peerB.ID()})
+
+			added, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 1, Chunk: []byte{1}, Sender: peerBID})
 			require.True(t, added)
 			require.NoError(t, err)
-			added, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 2, Chunk: []byte{2}, Sender: peerC.ID()})
+
+			added, err = chunks.Add(&chunk{Height: 1, Format: 1, Index: 2, Chunk: []byte{2}, Sender: peerCID})
 			require.True(t, added)
 			require.NoError(t, err)
 
 			// The first two chunks are accepted, before the last one asks for b sender to be rejected
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
-				Index: 0, Chunk: []byte{0}, Sender: "a",
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+				Index: 0, Chunk: []byte{0}, Sender: "aa",
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
-				Index: 1, Chunk: []byte{1}, Sender: "b",
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+				Index: 1, Chunk: []byte{1}, Sender: "bb",
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
-			connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
-				Index: 2, Chunk: []byte{2}, Sender: "c",
+			rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+				Index: 2, Chunk: []byte{2}, Sender: "cc",
 			}).Once().Return(&abci.ResponseApplySnapshotChunk{
 				Result: tc.result,
-				RejectSenders: []string{string(peerB.ID())},
+				RejectSenders: []string{peerBID.String()},
 			}, nil)
 
 			// On retry, the last chunk will be tried again, so we just accept it then.
 			if tc.result == abci.ResponseApplySnapshotChunk_RETRY {
-				connSnapshot.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
-					Index: 2, Chunk: []byte{2}, Sender: "c",
+				rts.conn.On("ApplySnapshotChunkSync", ctx, abci.RequestApplySnapshotChunk{
+					Index: 2, Chunk: []byte{2}, Sender: "cc",
 				}).Once().Return(&abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT}, nil)
 			}
 
@@ -584,23 +616,22 @@ func TestSyncer_applyChunks_RejectSenders(t *testing.T) {
 			// However, it will block on e.g. retry result, so we spawn a goroutine that will
 			// be shut down when the chunk queue closes.
 			go func() {
-				syncer.applyChunks(chunks) //nolint:errcheck // purposefully ignore error
+				rts.syncer.applyChunks(chunks) //nolint:errcheck // purposefully ignore error
 			}()
 			time.Sleep(50 * time.Millisecond)
 
-			s1peers := syncer.snapshots.GetPeers(s1)
-			assert.Len(t, s1peers, 2)
-			assert.EqualValues(t, "a", s1peers[0].ID())
-			assert.EqualValues(t, "c", s1peers[1].ID())
+			s1peers := rts.syncer.snapshots.GetPeers(s1)
+			require.Len(t, s1peers, 2)
+			require.EqualValues(t, "aa", s1peers[0].String())
+			require.EqualValues(t, "cc", s1peers[1].String())
 
-			syncer.snapshots.GetPeers(s1)
-			assert.Len(t, s1peers, 2)
-			assert.EqualValues(t, "a", s1peers[0].ID())
-			assert.EqualValues(t, "c", s1peers[1].ID())
+			rts.syncer.snapshots.GetPeers(s1)
+			require.Len(t, s1peers, 2)
+			require.EqualValues(t, "aa", s1peers[0].String())
+			require.EqualValues(t, "cc", s1peers[1].String())
 
-			err = chunks.Close()
-			require.NoError(t, err)
+			require.NoError(t, chunks.Close())
 		})
 	}
 }
@@ -634,20 +665,18 @@ func TestSyncer_verifyApp(t *testing.T) {
 	for name, tc := range testcases {
 		tc := tc
 		t.Run(name, func(t *testing.T) {
-			connQuery := &proxymocks.AppConnQuery{}
-			connSnapshot := &proxymocks.AppConnSnapshot{}
-			stateProvider := &mocks.StateProvider{}
-			syncer := newSyncer(log.NewNopLogger(), connSnapshot, connQuery, stateProvider, "")
+			rts := setup(t, nil, nil, nil, 2)
 
-			connQuery.On("InfoSync", ctx, proxy.RequestInfo).Return(tc.response, tc.err)
-			version, err := syncer.verifyApp(s)
+			rts.connQuery.On("InfoSync", ctx, proxy.RequestInfo).Return(tc.response, tc.err)
+			version, err := rts.syncer.verifyApp(s)
 			unwrapped := errors.Unwrap(err)
 			if unwrapped != nil {
 				err = unwrapped
 			}
-			assert.Equal(t, tc.expectErr, err)
+
+			require.Equal(t, tc.expectErr, err)
 			if err == nil {
-				assert.Equal(t, tc.response.AppVersion, version)
+				require.Equal(t, tc.response.AppVersion, version)
 			}
 		})
 	}
diff --git a/test/maverick/node/node.go b/test/maverick/node/node.go
index b9ae86143..a4c10273a 100644
--- a/test/maverick/node/node.go
+++ b/test/maverick/node/node.go
@@ -527,7 +527,7 @@ func createSwitch(config *cfg.Config,
 	peerFilters []p2p.PeerFilterFunc,
 	mempoolReactor *mempl.Reactor,
 	bcReactor p2p.Reactor,
-	stateSyncReactor *statesync.Reactor,
+	stateSyncReactor *p2p.ReactorShim,
 	consensusReactor *cs.Reactor,
 	evidenceReactor *evidence.Reactor,
 	nodeInfo p2p.NodeInfo,
@@ -790,9 +790,18 @@ func NewNode(config *cfg.Config,
 	// FIXME The way we do phased startups (e.g. replay -> fast sync -> consensus) is very messy,
 	// we should clean this whole thing up. See:
 	// https://github.com/tendermint/tendermint/issues/4644
-	stateSyncReactor := statesync.NewReactor(proxyApp.Snapshot(), proxyApp.Query(),
-		config.StateSync.TempDir)
-	stateSyncReactor.SetLogger(logger.With("module", "statesync"))
+	stateSyncReactorShim := p2p.NewReactorShim("StateSyncShim", statesync.ChannelShims)
+	stateSyncReactorShim.SetLogger(logger.With("module", "statesync"))
+
+	stateSyncReactor := statesync.NewReactor(
+		stateSyncReactorShim.Logger,
+		proxyApp.Snapshot(),
+		proxyApp.Query(),
+		stateSyncReactorShim.GetChannel(statesync.SnapshotChannel),
+		stateSyncReactorShim.GetChannel(statesync.ChunkChannel),
+		stateSyncReactorShim.PeerUpdates,
+		config.StateSync.TempDir,
+	)
 
 	nodeInfo, err := makeNodeInfo(config, nodeKey, txIndexer, genDoc, state)
 	if err != nil {
@@ -806,7 +815,7 @@ func NewNode(config *cfg.Config,
 	p2pLogger := logger.With("module", "p2p")
 	sw := createSwitch(
 		config, transport, p2pMetrics, peerFilters, mempoolReactor, bcReactor,
-		stateSyncReactor, consensusReactor, evidenceReactor, nodeInfo, nodeKey, p2pLogger,
+		stateSyncReactorShim, consensusReactor, evidenceReactor, nodeInfo, nodeKey, p2pLogger,
 	)
 
 	err = sw.AddPersistentPeers(splitAndTrimEmpty(config.P2P.PersistentPeers, ",", " "))
@@ -936,6 +945,11 @@ func (n *Node) OnStart() error {
 		return err
 	}
 
+	// Start the real state sync reactor separately since the switch uses the shim.
+	if err := n.stateSyncReactor.Start(); err != nil {
+		return err
+	}
+
 	// Always connect to persistent peers
 	err = n.sw.DialPeersAsync(splitAndTrimEmpty(n.config.P2P.PersistentPeers, ",", " "))
 	if err != nil {
@@ -977,6 +991,11 @@ func (n *Node) OnStop() {
 		n.Logger.Error("Error closing switch", "err", err)
 	}
 
+	// Stop the real state sync reactor separately since the switch uses the shim.
+	if err := n.stateSyncReactor.Stop(); err != nil {
+		n.Logger.Error("failed to stop state sync service", "err", err)
+	}
+
 	// stop mempool WAL
 	if n.config.Mempool.WalEnabled() {
 		n.mempool.CloseWAL()
@@ -1297,7 +1316,7 @@ func makeNodeInfo(
 			cs.StateChannel, cs.DataChannel, cs.VoteChannel, cs.VoteSetBitsChannel,
 			mempl.MempoolChannel,
 			evidence.EvidenceChannel,
-			statesync.SnapshotChannel, statesync.ChunkChannel,
+			byte(statesync.SnapshotChannel), byte(statesync.ChunkChannel),
 		},
 		Moniker: config.Moniker,
 		Other: p2p.DefaultNodeInfoOther{