diff --git a/lite2/client.go b/lite2/client.go index 2c487f9c9..61559a694 100644 --- a/lite2/client.go +++ b/lite2/client.go @@ -121,9 +121,9 @@ type Client struct { // Where trusted headers are stored. trustedStore store.Store // Highest trusted header from the store (height=H). - trustedHeader *types.SignedHeader + latestTrustedHeader *types.SignedHeader // Highest next validator set from the store (height=H+1). - trustedNextVals *types.ValidatorSet + latestTrustedNextVals *types.ValidatorSet // See UpdatePeriod option updatePeriod time.Duration @@ -164,13 +164,13 @@ func NewClient( return nil, err } - if c.trustedHeader != nil { + if c.latestTrustedHeader != nil { if err := c.checkTrustedHeaderUsingOptions(trustOptions); err != nil { return nil, err } } - if c.trustedHeader == nil || c.trustedHeader.Height < trustOptions.Height { + if c.latestTrustedHeader == nil || c.latestTrustedHeader.Height < trustOptions.Height { if err := c.initializeWithTrustOptions(trustOptions); err != nil { return nil, err } @@ -253,8 +253,8 @@ func (c *Client) restoreTrustedHeaderAndNextVals() error { return errors.Wrap(err, "can't get last trusted next validators") } - c.trustedHeader = trustedHeader - c.trustedNextVals = trustedNextVals + c.latestTrustedHeader = trustedHeader + c.latestTrustedNextVals = trustedNextVals c.logger.Debug("Restored trusted header and next vals", lastHeight) } @@ -283,24 +283,24 @@ func (c *Client) restoreTrustedHeaderAndNextVals() error { func (c *Client) checkTrustedHeaderUsingOptions(options TrustOptions) error { var primaryHash []byte switch { - case options.Height > c.trustedHeader.Height: - h, err := c.signedHeaderFromPrimary(c.trustedHeader.Height) + case options.Height > c.latestTrustedHeader.Height: + h, err := c.signedHeaderFromPrimary(c.latestTrustedHeader.Height) if err != nil { return err } primaryHash = h.Hash() - case options.Height == c.trustedHeader.Height: + case options.Height == c.latestTrustedHeader.Height: primaryHash = options.Hash - case options.Height < c.trustedHeader.Height: + case options.Height < c.latestTrustedHeader.Height: c.logger.Info("Client initialized with old header (trusted is more recent)", "old", options.Height, - "trustedHeight", c.trustedHeader.Height, - "trustedHash", hash2str(c.trustedHeader.Hash())) + "trustedHeight", c.latestTrustedHeader.Height, + "trustedHash", hash2str(c.latestTrustedHeader.Hash())) action := fmt.Sprintf( "Rollback to %d (%X)? Note this will remove newer headers up to %d (%X)", options.Height, options.Hash, - c.trustedHeader.Height, c.trustedHeader.Hash()) + c.latestTrustedHeader.Height, c.latestTrustedHeader.Hash()) if c.confirmationFn(action) { // remove all the headers (options.Height, trustedHeader.Height] c.cleanup(options.Height + 1) @@ -314,13 +314,13 @@ func (c *Client) checkTrustedHeaderUsingOptions(options TrustOptions) error { primaryHash = options.Hash } - if !bytes.Equal(primaryHash, c.trustedHeader.Hash()) { + if !bytes.Equal(primaryHash, c.latestTrustedHeader.Hash()) { c.logger.Info("Prev. trusted header's hash (h1) doesn't match hash from primary provider (h2)", - "h1", c.trustedHeader.Hash(), "h1", primaryHash) + "h1", hash2str(c.latestTrustedHeader.Hash()), "h2", hash2str(primaryHash)) action := fmt.Sprintf( "Prev. trusted header's hash %X doesn't match hash %X from primary provider. Remove all the stored headers?", - c.trustedHeader.Hash(), primaryHash) + c.latestTrustedHeader.Hash(), primaryHash) if c.confirmationFn(action) { err := c.Cleanup() if err != nil { @@ -449,14 +449,7 @@ func (c *Client) TrustedHeader(height int64, now time.Time) (*types.SignedHeader // 2) Get header from store. h, err := c.trustedStore.SignedHeader(height) - switch { - case errors.Is(err, store.ErrSignedHeaderNotFound): - // 2.1) If not found, try to fetch header from primary. - h, err = c.fetchMissingTrustedHeader(height, now) - if err != nil { - return nil, err - } - case err != nil: + if err != nil { return nil, err } @@ -523,15 +516,25 @@ func (c *Client) ChainID() string { return c.chainID } -// VerifyHeaderAtHeight fetches the header and validators at the given height -// and calls VerifyHeader. +// VerifyHeaderAtHeight fetches header and validators at the given height +// and calls VerifyHeader. It returns header immediately if such exists in +// trustedStore (no verification is needed). // -// If the trusted header is more recent than one here, an error is returned. -// If the header is not found by the primary provider, -// provider.ErrSignedHeaderNotFound error is returned. +// It returns provider.ErrSignedHeaderNotFound if header is not found by +// primary. +// It returns ErrOldHeaderExpired if header expired. func (c *Client) VerifyHeaderAtHeight(height int64, now time.Time) (*types.SignedHeader, error) { - if c.trustedHeader.Height >= height { - return nil, errors.Errorf("header at more recent height #%d exists", c.trustedHeader.Height) + if height <= 0 { + return nil, errors.New("negative or zero height") + } + + h, err := c.TrustedHeader(height, now) + switch err.(type) { + case nil: // Return already trusted header + c.logger.Info("Header has already been verified", "height", height, "hash", hash2str(h.Hash())) + return h, nil + case ErrOldHeaderExpired: + return nil, err } // Request the header and the vals. @@ -540,10 +543,12 @@ func (c *Client) VerifyHeaderAtHeight(height int64, now time.Time) (*types.Signe return nil, err } - return newHeader, c.VerifyHeader(newHeader, newVals, now) + return newHeader, c.verifyHeader(newHeader, newVals, now) } -// VerifyHeader verifies new header against the trusted state. +// VerifyHeader verifies new header against the trusted state. It returns +// immediately if newHeader exists in trustedStore (no verification is +// needed). // // SequentialVerification: verifies that 2/3 of the trusted validator set has // signed the new header. If the headers are not adjacent, **all** intermediate @@ -555,7 +560,7 @@ func (c *Client) VerifyHeaderAtHeight(height int64, now time.Time) (*types.Signe // intermediate headers will be requested. See the specification for details. // https://github.com/tendermint/spec/blob/master/spec/consensus/light-client.md // -// If the trusted header is more recent than one here, an error is returned. +// It returns ErrOldHeaderExpired if newHeader expired. // // If, at any moment, SignedHeader or ValidatorSet are not found by the primary // provider, provider.ErrSignedHeaderNotFound / @@ -565,21 +570,49 @@ func (c *Client) VerifyHeaderAtHeight(height int64, now time.Time) (*types.Signe // validator set at height newHeader.Height+1 (i.e. // newHeader.NextValidatorsHash). func (c *Client) VerifyHeader(newHeader *types.SignedHeader, newVals *types.ValidatorSet, now time.Time) error { + h, err := c.TrustedHeader(newHeader.Height, now) + switch err.(type) { + case nil: // Return already trusted header + // Make sure it's the same header. + if !bytes.Equal(h.Hash(), newHeader.Hash()) { + return errors.Errorf("existing trusted header %X does not match newHeader %X", h.Hash(), newHeader.Hash()) + } + c.logger.Info("Header has already been verified", + "height", newHeader.Height, "hash", hash2str(newHeader.Hash())) + return nil + case ErrOldHeaderExpired: + return err + } + + return c.verifyHeader(newHeader, newVals, now) +} + +func (c *Client) verifyHeader(newHeader *types.SignedHeader, newVals *types.ValidatorSet, now time.Time) error { c.logger.Info("VerifyHeader", "height", newHeader.Height, "hash", hash2str(newHeader.Hash()), "vals", hash2str(newVals.Hash())) - if c.trustedHeader.Height >= newHeader.Height { - return errors.Errorf("header at more recent height #%d exists", c.trustedHeader.Height) - } - var err error - switch c.verificationMode { - case sequential: - err = c.sequence(c.trustedHeader, c.trustedNextVals, newHeader, newVals, now) - case skipping: - err = c.bisection(c.trustedHeader, c.trustedNextVals, newHeader, newVals, now) - default: - panic(fmt.Sprintf("Unknown verification mode: %b", c.verificationMode)) + + // 1) If going forward, perform either bisection or sequential verification + if newHeader.Height >= c.latestTrustedHeader.Height { + switch c.verificationMode { + case sequential: + err = c.sequence(c.latestTrustedHeader, c.latestTrustedNextVals, newHeader, newVals, now) + case skipping: + err = c.bisection(c.latestTrustedHeader, c.latestTrustedNextVals, newHeader, newVals, now) + default: + panic(fmt.Sprintf("Unknown verification mode: %b", c.verificationMode)) + } + } else { + // 2) Otherwise, perform backwards verification + // Find the closest trusted header after newHeader.Height + var closestHeader *types.SignedHeader + closestHeader, err = c.trustedStore.SignedHeaderAfter(newHeader.Height) + if err != nil { + return errors.Wrapf(err, "can't get signed header after height %d", newHeader.Height) + } + + err = c.backwards(closestHeader, newHeader, now) } if err != nil { c.logger.Error("Can't verify", "err", err) @@ -596,6 +629,7 @@ func (c *Client) VerifyHeader(newHeader *types.SignedHeader, newVals *types.Vali if err != nil { return err } + return c.updateTrustedHeaderAndNextVals(newHeader, nextVals) } @@ -652,8 +686,8 @@ func (c *Client) cleanup(stopHeight int64) error { } } - c.trustedHeader = nil - c.trustedNextVals = nil + c.latestTrustedHeader = nil + c.latestTrustedNextVals = nil err = c.restoreTrustedHeaderAndNextVals() if err != nil { return err @@ -682,8 +716,8 @@ func (c *Client) sequence( } c.logger.Debug("Verify newHeader against trustedHeader", - "trustedHeight", c.trustedHeader.Height, - "trustedHash", hash2str(c.trustedHeader.Hash()), + "trustedHeight", c.latestTrustedHeader.Height, + "trustedHash", hash2str(c.latestTrustedHeader.Hash()), "newHeight", interimHeader.Height, "newHash", hash2str(interimHeader.Hash())) err = VerifyAdjacent(c.chainID, trustedHeader, interimHeader, trustedNextVals, @@ -709,7 +743,7 @@ func (c *Client) sequence( } // 2) Verify the new header. - return VerifyAdjacent(c.chainID, c.trustedHeader, newHeader, newVals, c.trustingPeriod, now) + return VerifyAdjacent(c.chainID, c.latestTrustedHeader, newHeader, newVals, c.trustingPeriod, now) } // see VerifyHeader @@ -779,8 +813,8 @@ func (c *Client) updateTrustedHeaderAndNextVals(h *types.SignedHeader, nextVals return errors.Wrap(err, "failed to save trusted header") } - c.trustedHeader = h - c.trustedNextVals = nextVals + c.latestTrustedHeader = h + c.latestTrustedNextVals = nextVals return nil } @@ -799,77 +833,44 @@ func (c *Client) fetchHeaderAndValsAtHeight(height int64) (*types.SignedHeader, return h, vals, nil } -// fetchMissingTrustedHeader finds the closest height after the -// requested height and does backwards verification. -func (c *Client) fetchMissingTrustedHeader(height int64, now time.Time) (*types.SignedHeader, error) { - c.logger.Info("Fetching missing header", "height", height) - - closestHeader, err := c.trustedStore.SignedHeaderAfter(height) - if err != nil { - return nil, errors.Wrapf(err, "can't get signed header after %d", height) - } - - // Perform backwards verification from closestHeader to header at the given - // height. - h, err := c.backwards(height, closestHeader, now) - if err != nil { - return nil, err - } - - // Fetch next validator set from primary and persist it. - nextVals, err := c.validatorSetFromPrimary(height + 1) - if err != nil { - return nil, errors.Wrapf(err, "failed to obtain the vals #%d", height) - } - if !bytes.Equal(h.NextValidatorsHash, nextVals.Hash()) { - return nil, errors.Errorf("expected next validator's hash %X, but got %X", - h.NextValidatorsHash, nextVals.Hash()) - } - if err := c.trustedStore.SaveSignedHeaderAndNextValidatorSet(h, nextVals); err != nil { - return nil, errors.Wrap(err, "failed to save trusted header") - } - - return h, nil -} - // Backwards verification (see VerifyHeaderBackwards func in the spec) -func (c *Client) backwards(toHeight int64, fromHeader *types.SignedHeader, now time.Time) (*types.SignedHeader, error) { +func (c *Client) backwards(trustedHeader *types.SignedHeader, newHeader *types.SignedHeader, + now time.Time) error { var ( - trustedHeader = fromHeader - untrustedHeader *types.SignedHeader - err error + interimHeader *types.SignedHeader + err error ) - for i := trustedHeader.Height - 1; i >= toHeight; i-- { - untrustedHeader, err = c.signedHeaderFromPrimary(i) + for trustedHeader.Height > newHeader.Height { + interimHeader, err = c.signedHeaderFromPrimary(trustedHeader.Height - 1) if err != nil { - return nil, errors.Wrapf(err, "failed to obtain the header #%d", i) + return errors.Wrapf(err, "failed to obtain the header at height #%d", trustedHeader.Height-1) } - if err := untrustedHeader.ValidateBasic(c.chainID); err != nil { - return nil, errors.Wrap(err, "untrustedHeader.ValidateBasic failed") + if err := interimHeader.ValidateBasic(c.chainID); err != nil { + return errors.Wrap(err, "untrustedHeader.ValidateBasic failed") } - if !untrustedHeader.Time.Before(trustedHeader.Time) { - return nil, errors.Errorf("expected older header time %v to be before newer header time %v", - untrustedHeader.Time, + if !interimHeader.Time.Before(trustedHeader.Time) { + return errors.Errorf("expected older header time %v to be before newer header time %v", + interimHeader.Time, trustedHeader.Time) } - if HeaderExpired(untrustedHeader, c.trustingPeriod, now) { - return nil, ErrOldHeaderExpired{untrustedHeader.Time.Add(c.trustingPeriod), now} + if HeaderExpired(interimHeader, c.trustingPeriod, now) { + return ErrOldHeaderExpired{interimHeader.Time.Add(c.trustingPeriod), now} } - if !bytes.Equal(untrustedHeader.Hash(), trustedHeader.LastBlockID.Hash) { - return nil, errors.Errorf("older header hash %X does not match trusted header's last block %X", - untrustedHeader.Hash(), + if !bytes.Equal(interimHeader.Hash(), trustedHeader.LastBlockID.Hash) { + return errors.Errorf("older header hash %X does not match trusted header's last block %X", + interimHeader.Hash(), trustedHeader.LastBlockID.Hash) } - trustedHeader = untrustedHeader + trustedHeader = interimHeader } - return trustedHeader, nil + return nil } // compare header with all witnesses provided. @@ -899,7 +900,7 @@ func (c *Client) compareNewHeaderWithWitnesses(h *types.SignedHeader) error { } if !bytes.Equal(h.Hash(), altH.Hash()) { - if err = c.trustedNextVals.VerifyCommitTrusting(c.chainID, altH.Commit.BlockID, + if err = c.latestTrustedNextVals.VerifyCommitTrusting(c.chainID, altH.Commit.BlockID, altH.Height, altH.Commit, c.trustLevel); err != nil { c.logger.Error("Witness sent us incorrect header", "err", err, "witness", witness) witnessesToRemove = append(witnessesToRemove, i) @@ -989,7 +990,7 @@ func (c *Client) RemoveNoLongerTrustedHeaders(now time.Time) { // 3) Remove all headers that are outside of the trusting period. // // NOTE: even the latest header can be removed. it's okay because - // c.trustedHeader will retain it in memory so other funcs like VerifyHeader + // c.latestTrustedHeader will retain it in memory so other funcs like VerifyHeader // don't crash. for height := oldestHeight; height <= latestHeight; height++ { h, err := c.trustedStore.SignedHeader(height) diff --git a/lite2/client_test.go b/lite2/client_test.go index d8f44adc6..a6ffc31ab 100644 --- a/lite2/client_test.go +++ b/lite2/client_test.go @@ -732,7 +732,7 @@ func TestClientReplacesPrimaryWithWitnessIfPrimaryIsUnavailable(t *testing.T) { assert.Equal(t, 1, len(c.Witnesses())) } -func TestClient_TrustedHeaderFetchesMissingHeader(t *testing.T) { +func TestClient_BackwardsVerification(t *testing.T) { c, err := NewClient( chainID, TrustOptions{ @@ -752,16 +752,16 @@ func TestClient_TrustedHeaderFetchesMissingHeader(t *testing.T) { defer c.Stop() // 1) header is missing => expect no error - h, err := c.TrustedHeader(2, bTime.Add(1*time.Hour).Add(1*time.Second)) + h, err := c.VerifyHeaderAtHeight(2, bTime.Add(1*time.Hour).Add(1*time.Second)) require.NoError(t, err) if assert.NotNil(t, h) { assert.EqualValues(t, 2, h.Height) } // 2) header is missing, but it's expired => expect error - h, err = c.TrustedHeader(1, bTime.Add(1*time.Hour).Add(1*time.Second)) + h, err = c.VerifyHeaderAtHeight(1, bTime.Add(1*time.Hour).Add(1*time.Second)) assert.Error(t, err) - assert.Nil(t, h) + assert.NotNil(t, h) } func TestClient_NewClientFromTrustedStore(t *testing.T) {