diff --git a/cmd/tendermint/commands/rollback.go b/cmd/tendermint/commands/rollback.go index 5aff232be..c19d35cce 100644 --- a/cmd/tendermint/commands/rollback.go +++ b/cmd/tendermint/commands/rollback.go @@ -40,6 +40,10 @@ func RollbackState(config *cfg.Config) (int64, []byte, error) { if err != nil { return -1, nil, err } + defer func() { + _ = blockStore.Close() + _ = stateStore.Close() + }() // rollback the last state return state.Rollback(blockStore, stateStore) diff --git a/cmd/tendermint/commands/rollback_test.go b/cmd/tendermint/commands/rollback_test.go new file mode 100644 index 000000000..f842ddd0e --- /dev/null +++ b/cmd/tendermint/commands/rollback_test.go @@ -0,0 +1,71 @@ +package commands_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/tendermint/tendermint/cmd/tendermint/commands" + "github.com/tendermint/tendermint/rpc/client/local" + rpctest "github.com/tendermint/tendermint/rpc/test" + e2e "github.com/tendermint/tendermint/test/e2e/app" +) + +func TestRollbackIntegration(t *testing.T) { + var height int64 + dir := t.TempDir() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cfg, err := rpctest.CreateConfig(t.Name()) + require.NoError(t, err) + cfg.BaseConfig.DBBackend = "goleveldb" + app, err := e2e.NewApplication(e2e.DefaultConfig(dir)) + + t.Run("First run", func(t *testing.T) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + require.NoError(t, err) + node, _, err := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) + require.NoError(t, err) + + time.Sleep(3 * time.Second) + cancel() + node.Wait() + require.False(t, node.IsRunning()) + }) + + t.Run("Rollback", func(t *testing.T) { + require.NoError(t, app.Rollback()) + height, _, err = commands.RollbackState(cfg) + require.NoError(t, err) + + }) + + t.Run("Restart", func(t *testing.T) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + node2, _, err2 := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout) + require.NoError(t, err2) + + client, err := local.New(node2.(local.NodeService)) + require.NoError(t, err) + + ticker := time.NewTicker(200 * time.Millisecond) + for { + select { + case <-ctx.Done(): + t.Fatalf("failed to make progress after 20 seconds. Min height: %d", height) + case <-ticker.C: + status, err := client.Status(ctx) + require.NoError(t, err) + + if status.SyncInfo.LatestBlockHeight > height { + return + } + } + } + }) + +} diff --git a/internal/evidence/pool.go b/internal/evidence/pool.go index f342dec4c..d99ff2d54 100644 --- a/internal/evidence/pool.go +++ b/internal/evidence/pool.go @@ -261,6 +261,10 @@ func (evpool *Pool) State() sm.State { return evpool.state } +func (evpool *Pool) Close() error { + return evpool.evidenceStore.Close() +} + // IsExpired checks whether evidence or a polc is expired by checking whether a height and time is older // than set by the evidence consensus parameters func (evpool *Pool) isExpired(height int64, time time.Time) bool { diff --git a/internal/evidence/reactor.go b/internal/evidence/reactor.go index 4e37e1d17..89f60b749 100644 --- a/internal/evidence/reactor.go +++ b/internal/evidence/reactor.go @@ -111,6 +111,9 @@ func (r *Reactor) OnStop() { // panics will occur. <-r.evidenceCh.Done() <-r.peerUpdates.Done() + + // Close the evidence db + r.evpool.Close() } // handleEvidenceMessage handles envelopes sent from peers on the EvidenceChannel. diff --git a/internal/state/mocks/store.go b/internal/state/mocks/store.go index 4452f9bec..02c69d3e0 100644 --- a/internal/state/mocks/store.go +++ b/internal/state/mocks/store.go @@ -29,6 +29,20 @@ func (_m *Store) Bootstrap(_a0 state.State) error { return r0 } +// Close provides a mock function with given fields: +func (_m *Store) Close() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + // Load provides a mock function with given fields: func (_m *Store) Load() (state.State, error) { ret := _m.Called() diff --git a/internal/state/rollback.go b/internal/state/rollback.go index e78957b02..ea0eff4de 100644 --- a/internal/state/rollback.go +++ b/internal/state/rollback.go @@ -36,18 +36,18 @@ func Rollback(bs BlockStore, ss Store) (int64, []byte, error) { } // state store height is equal to blockstore height. We're good to proceed with rolling back state - rollbackHeight := invalidState.LastBlockHeight + rollbackHeight := invalidState.LastBlockHeight - 1 rollbackBlock := bs.LoadBlockMeta(rollbackHeight) if rollbackBlock == nil { return -1, nil, fmt.Errorf("block at height %d not found", rollbackHeight) } - previousValidatorSet, err := ss.LoadValidators(rollbackHeight - 1) + previousLastValidatorSet, err := ss.LoadValidators(rollbackHeight) if err != nil { return -1, nil, err } - previousParams, err := ss.LoadConsensusParams(rollbackHeight) + previousParams, err := ss.LoadConsensusParams(rollbackHeight + 1) if err != nil { return -1, nil, err } @@ -55,13 +55,13 @@ func Rollback(bs BlockStore, ss Store) (int64, []byte, error) { valChangeHeight := invalidState.LastHeightValidatorsChanged // this can only happen if the validator set changed since the last block if valChangeHeight > rollbackHeight { - valChangeHeight = rollbackHeight + valChangeHeight = rollbackHeight + 1 } paramsChangeHeight := invalidState.LastHeightConsensusParamsChanged // this can only happen if params changed from the last block if paramsChangeHeight > rollbackHeight { - paramsChangeHeight = rollbackHeight + paramsChangeHeight = rollbackHeight + 1 } // build the new state from the old state and the prior block @@ -77,13 +77,13 @@ func Rollback(bs BlockStore, ss Store) (int64, []byte, error) { ChainID: invalidState.ChainID, InitialHeight: invalidState.InitialHeight, - LastBlockHeight: invalidState.LastBlockHeight - 1, - LastBlockID: rollbackBlock.Header.LastBlockID, + LastBlockHeight: rollbackBlock.Header.Height, + LastBlockID: rollbackBlock.BlockID, LastBlockTime: rollbackBlock.Header.Time, NextValidators: invalidState.Validators, Validators: invalidState.LastValidators, - LastValidators: previousValidatorSet, + LastValidators: previousLastValidatorSet, LastHeightValidatorsChanged: valChangeHeight, ConsensusParams: previousParams, diff --git a/internal/state/rollback_test.go b/internal/state/rollback_test.go index e782b4d89..fb5ca9796 100644 --- a/internal/state/rollback_test.go +++ b/internal/state/rollback_test.go @@ -15,50 +15,49 @@ import ( func TestRollback(t *testing.T) { var ( - height int64 = 100 - appVersion uint64 = 10 + height int64 = 100 + nextHeight int64 = 101 ) blockStore := &mocks.BlockStore{} stateStore := setupStateStore(t, height) initialState, err := stateStore.Load() require.NoError(t, err) - height++ - block := &types.BlockMeta{ - Header: types.Header{ - Height: height, - AppHash: initialState.AppHash, - LastBlockID: initialState.LastBlockID, - LastResultsHash: initialState.LastResultsHash, - }, - } - blockStore.On("LoadBlockMeta", height).Return(block) - blockStore.On("Height").Return(height) - // perform the rollback over a version bump - appVersion++ newParams := types.DefaultConsensusParams() - newParams.Version.AppVersion = appVersion + newParams.Version.AppVersion = 11 newParams.Block.MaxBytes = 1000 nextState := initialState.Copy() - nextState.LastBlockHeight = height - nextState.Version.Consensus.App = appVersion + nextState.LastBlockHeight = nextHeight + nextState.Version.Consensus.App = 11 nextState.LastBlockID = factory.MakeBlockID() nextState.AppHash = factory.RandomHash() nextState.LastValidators = initialState.Validators nextState.Validators = initialState.NextValidators nextState.NextValidators = initialState.NextValidators.CopyIncrementProposerPriority(1) nextState.ConsensusParams = *newParams - nextState.LastHeightConsensusParamsChanged = height + 1 - nextState.LastHeightValidatorsChanged = height + 1 + nextState.LastHeightConsensusParamsChanged = nextHeight + 1 + nextState.LastHeightValidatorsChanged = nextHeight + 1 // update the state require.NoError(t, stateStore.Save(nextState)) + block := &types.BlockMeta{ + BlockID: initialState.LastBlockID, + Header: types.Header{ + Height: initialState.LastBlockHeight, + AppHash: initialState.AppHash, + LastBlockID: factory.MakeBlockID(), + LastResultsHash: initialState.LastResultsHash, + }, + } + blockStore.On("LoadBlockMeta", initialState.LastBlockHeight).Return(block) + blockStore.On("Height").Return(nextHeight) + // rollback the state rollbackHeight, rollbackHash, err := state.Rollback(blockStore, stateStore) require.NoError(t, err) - require.EqualValues(t, int64(100), rollbackHeight) + require.EqualValues(t, height, rollbackHeight) require.EqualValues(t, initialState.AppHash, rollbackHash) blockStore.AssertExpectations(t) @@ -82,11 +81,11 @@ func TestRollbackNoBlocks(t *testing.T) { stateStore := setupStateStore(t, height) blockStore := &mocks.BlockStore{} blockStore.On("Height").Return(height) - blockStore.On("LoadBlockMeta", height).Return(nil) + blockStore.On("LoadBlockMeta", height-1).Return(nil) _, _, err := state.Rollback(blockStore, stateStore) require.Error(t, err) - require.Contains(t, err.Error(), "block at height 100 not found") + require.Contains(t, err.Error(), "block at height 99 not found") } func TestRollbackDifferentStateHeight(t *testing.T) { diff --git a/internal/state/store.go b/internal/state/store.go index 0f1d2b444..de17be0d7 100644 --- a/internal/state/store.go +++ b/internal/state/store.go @@ -92,6 +92,8 @@ type Store interface { Bootstrap(State) error // PruneStates takes the height from which to prune up to (exclusive) PruneStates(int64) error + // Close closes the connection with the database + Close() error } // dbStore wraps a db (github.com/tendermint/tm-db) @@ -658,3 +660,7 @@ func (store dbStore) saveConsensusParamsInfo( return batch.Set(consensusParamsKey(nextHeight), bz) } + +func (store dbStore) Close() error { + return store.db.Close() +} diff --git a/internal/store/store.go b/internal/store/store.go index 6cdcdf719..88a33a585 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -572,6 +572,10 @@ func (bs *BlockStore) SaveSignedHeader(sh *types.SignedHeader, blockID types.Blo return batch.Close() } +func (bs *BlockStore) Close() error { + return bs.db.Close() +} + //---------------------------------- KEY ENCODING ----------------------------------------- // key prefixes diff --git a/node/node.go b/node/node.go index 86ae173e5..b9cdea4e6 100644 --- a/node/node.go +++ b/node/node.go @@ -704,6 +704,16 @@ func (n *nodeImpl) OnStop() { n.Logger.Error("problem shutting down additional services", "err", err) } } + if n.blockStore != nil { + if err := n.blockStore.Close(); err != nil { + n.Logger.Error("problem closing blockstore", "err", err) + } + } + if n.stateStore != nil { + if err := n.stateStore.Close(); err != nil { + n.Logger.Error("problem closing statestore", "err", err) + } + } } func (n *nodeImpl) startRPC(ctx context.Context) ([]net.Listener, error) { diff --git a/node/setup.go b/node/setup.go index 55ca592e0..ca47a9c25 100644 --- a/node/setup.go +++ b/node/setup.go @@ -77,7 +77,7 @@ func initDBs( blockStoreDB, err := dbProvider(&config.DBContext{ID: "blockstore", Config: cfg}) if err != nil { - return nil, nil, func() error { return nil }, err + return nil, nil, func() error { return nil }, fmt.Errorf("unable to initialize blockstore: %w", err) } closers := []closer{} blockStore := store.NewBlockStore(blockStoreDB) @@ -85,7 +85,7 @@ func initDBs( stateDB, err := dbProvider(&config.DBContext{ID: "state", Config: cfg}) if err != nil { - return nil, nil, makeCloser(closers), err + return nil, nil, makeCloser(closers), fmt.Errorf("unable to initialize statestore: %w", err) } closers = append(closers, stateDB.Close) @@ -243,7 +243,7 @@ func createEvidenceReactor( ) (*evidence.Reactor, *evidence.Pool, error) { evidenceDB, err := dbProvider(&config.DBContext{ID: "evidence", Config: cfg}) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("unable to initialize evidence db: %w", err) } logger = logger.With("module", "evidence") @@ -432,7 +432,7 @@ func createPeerManager( peerDB, err := dbProvider(&config.DBContext{ID: "peerstore", Config: cfg}) if err != nil { - return nil, func() error { return nil }, err + return nil, func() error { return nil }, fmt.Errorf("unable to initialize peer store: %w", err) } peerManager, err := p2p.NewPeerManager(nodeID, peerDB, options) diff --git a/test/e2e/app/app.go b/test/e2e/app/app.go index 353bb6b75..5a782fa33 100644 --- a/test/e2e/app/app.go +++ b/test/e2e/app/app.go @@ -80,7 +80,7 @@ func DefaultConfig(dir string) *Config { // NewApplication creates the application. func NewApplication(cfg *Config) (*Application, error) { - state, err := NewState(filepath.Join(cfg.Dir, "state.json"), cfg.PersistInterval) + state, err := NewState(cfg.Dir, cfg.PersistInterval) if err != nil { return nil, err } @@ -267,6 +267,10 @@ func (app *Application) ApplySnapshotChunk(req abci.RequestApplySnapshotChunk) a return abci.ResponseApplySnapshotChunk{Result: abci.ResponseApplySnapshotChunk_ACCEPT} } +func (app *Application) Rollback() error { + return app.state.Rollback() +} + // validatorUpdates generates a validator set update. func (app *Application) validatorUpdates(height uint64) (abci.ValidatorUpdates, error) { updates := app.cfg.ValidatorUpdates[fmt.Sprintf("%v", height)] diff --git a/test/e2e/app/state.go b/test/e2e/app/state.go index 7376b8776..e82a22539 100644 --- a/test/e2e/app/state.go +++ b/test/e2e/app/state.go @@ -7,10 +7,14 @@ import ( "errors" "fmt" "os" + "path/filepath" "sort" "sync" ) +const stateFileName = "app_state.json" +const prevStateFileName = "prev_app_state.json" + // State is the application state. type State struct { sync.RWMutex @@ -19,16 +23,19 @@ type State struct { Hash []byte // private fields aren't marshaled to disk. - file string + currentFile string + // app saves current and previous state for rollback functionality + previousFile string persistInterval uint64 initialHeight uint64 } // NewState creates a new state. -func NewState(file string, persistInterval uint64) (*State, error) { +func NewState(dir string, persistInterval uint64) (*State, error) { state := &State{ Values: make(map[string]string), - file: file, + currentFile: filepath.Join(dir, stateFileName), + previousFile: filepath.Join(dir, prevStateFileName), persistInterval: persistInterval, } state.Hash = hashItems(state.Values) @@ -44,13 +51,22 @@ func NewState(file string, persistInterval uint64) (*State, error) { // load loads state from disk. It does not take out a lock, since it is called // during construction. func (s *State) load() error { - bz, err := os.ReadFile(s.file) + bz, err := os.ReadFile(s.currentFile) if err != nil { - return fmt.Errorf("failed to read state from %q: %w", s.file, err) + // if the current state doesn't exist then we try recover from the previous state + if errors.Is(err, os.ErrNotExist) { + bz, err = os.ReadFile(s.previousFile) + if err != nil { + return fmt.Errorf("failed to read both current and previous state (%q): %w", + s.previousFile, err) + } + } else { + return fmt.Errorf("failed to read state from %q: %w", s.currentFile, err) + } } err = json.Unmarshal(bz, s) if err != nil { - return fmt.Errorf("invalid state data in %q: %w", s.file, err) + return fmt.Errorf("invalid state data in %q: %w", s.currentFile, err) } return nil } @@ -64,12 +80,19 @@ func (s *State) save() error { } // We write the state to a separate file and move it to the destination, to // make it atomic. - newFile := fmt.Sprintf("%v.new", s.file) + newFile := fmt.Sprintf("%v.new", s.currentFile) err = os.WriteFile(newFile, bz, 0644) if err != nil { - return fmt.Errorf("failed to write state to %q: %w", s.file, err) + return fmt.Errorf("failed to write state to %q: %w", s.currentFile, err) + } + // We take the current state and move it to the previous state, replacing it + if _, err := os.Stat(s.currentFile); err == nil { + if err := os.Rename(s.currentFile, s.previousFile); err != nil { + return fmt.Errorf("failed to replace previous state: %w", err) + } } - return os.Rename(newFile, s.file) + // Finally, we take the new state and replace the current state. + return os.Rename(newFile, s.currentFile) } // Export exports key/value pairs as JSON, used for state sync snapshots. @@ -135,6 +158,18 @@ func (s *State) Commit() (uint64, []byte, error) { return s.Height, s.Hash, nil } +func (s *State) Rollback() error { + bz, err := os.ReadFile(s.previousFile) + if err != nil { + return fmt.Errorf("failed to read state from %q: %w", s.previousFile, err) + } + err = json.Unmarshal(bz, s) + if err != nil { + return fmt.Errorf("invalid state data in %q: %w", s.previousFile, err) + } + return nil +} + // hashItems hashes a set of key/value items. func hashItems(items map[string]string) []byte { keys := make([]string, 0, len(items))