diff --git a/lite2/client.go b/lite2/client.go index 61559a694..65040c133 100644 --- a/lite2/client.go +++ b/lite2/client.go @@ -703,13 +703,14 @@ func (c *Client) sequence( newHeader *types.SignedHeader, newVals *types.ValidatorSet, now time.Time) error { + // 1) Verify any intermediate headers. var ( interimHeader *types.SignedHeader interimNextVals *types.ValidatorSet err error ) - for height := trustedHeader.Height + 1; height < newHeader.Height; height++ { + for height := trustedHeader.Height + 1; height <= newHeader.Height; height++ { interimHeader, err = c.signedHeaderFromPrimary(height) if err != nil { return errors.Wrapf(err, "failed to obtain the header #%d", height) @@ -734,16 +735,17 @@ func (c *Client) sequence( if err != nil { return errors.Wrapf(err, "failed to obtain the vals #%d", height+1) } - } - err = c.updateTrustedHeaderAndNextVals(interimHeader, interimNextVals) - if err != nil { - return errors.Wrapf(err, "failed to update trusted state #%d", height) + if !bytes.Equal(interimHeader.NextValidatorsHash, interimNextVals.Hash()) { + return errors.Errorf("expected next validator's hash %X, but got %X (height #%d)", + interimHeader.NextValidatorsHash, + interimNextVals.Hash(), + interimHeader.Height) + } } trustedHeader, trustedNextVals = interimHeader, interimNextVals } - // 2) Verify the new header. - return VerifyAdjacent(c.chainID, c.latestTrustedHeader, newHeader, newVals, c.trustingPeriod, now) + return nil } // see VerifyHeader @@ -757,7 +759,7 @@ func (c *Client) bisection( interimVals := newVals interimHeader := newHeader - for trustedHeader.Height < newHeader.Height { + for { c.logger.Debug("Verify newHeader against trustedHeader", "trustedHeight", trustedHeader.Height, "trustedHash", hash2str(trustedHeader.Hash()), @@ -767,24 +769,22 @@ func (c *Client) bisection( c.trustLevel) switch err.(type) { case nil: + if interimHeader.Height == newHeader.Height { + return nil + } + // Update the lower bound to the previous upper bound - trustedHeader = interimHeader - trustedNextVals, err = c.validatorSetFromPrimary(interimHeader.Height + 1) + interimNextVals, err := c.validatorSetFromPrimary(interimHeader.Height + 1) if err != nil { return err } - if !bytes.Equal(trustedHeader.NextValidatorsHash, trustedNextVals.Hash()) { + if !bytes.Equal(interimHeader.NextValidatorsHash, interimNextVals.Hash()) { return errors.Errorf("expected next validator's hash %X, but got %X (height #%d)", - trustedHeader.NextValidatorsHash, - trustedNextVals.Hash(), - trustedHeader.Height) - } - - err = c.updateTrustedHeaderAndNextVals(trustedHeader, trustedNextVals) - if err != nil { - return err + interimHeader.NextValidatorsHash, + interimNextVals.Hash(), + interimHeader.Height) } - + trustedHeader, trustedNextVals = interimHeader, interimNextVals // Update the upper bound to the untrustedHeader interimHeader, interimVals = newHeader, newVals @@ -799,8 +799,6 @@ func (c *Client) bisection( return errors.Wrapf(err, "failed to verify the header #%d", newHeader.Height) } } - - return nil } // persist header and next validators to trustedStore. @@ -841,6 +839,10 @@ func (c *Client) backwards(trustedHeader *types.SignedHeader, newHeader *types.S err error ) + if HeaderExpired(newHeader, c.trustingPeriod, now) { + return ErrOldHeaderExpired{newHeader.Time.Add(c.trustingPeriod), now} + } + for trustedHeader.Height > newHeader.Height { interimHeader, err = c.signedHeaderFromPrimary(trustedHeader.Height - 1) if err != nil { @@ -857,10 +859,6 @@ func (c *Client) backwards(trustedHeader *types.SignedHeader, newHeader *types.S trustedHeader.Time) } - if HeaderExpired(interimHeader, c.trustingPeriod, now) { - return ErrOldHeaderExpired{interimHeader.Time.Add(c.trustingPeriod), now} - } - if !bytes.Equal(interimHeader.Hash(), trustedHeader.LastBlockID.Hash) { return errors.Errorf("older header hash %X does not match trusted header's last block %X", interimHeader.Hash(), diff --git a/lite2/client_test.go b/lite2/client_test.go index a6ffc31ab..a6b615227 100644 --- a/lite2/client_test.go +++ b/lite2/client_test.go @@ -57,6 +57,9 @@ var ( ) func TestClient_SequentialVerification(t *testing.T) { + newKeys := genPrivKeys(4) + newVals := newKeys.ToValidators(10, 1) + testCases := []struct { name string otherHeaders map[int64]*types.SignedHeader // all except ^ @@ -138,6 +141,25 @@ func TestClient_SequentialVerification(t *testing.T) { false, true, }, + { + "bad: different validator set at height 3", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (3/3 signed) + 2: h2, + // last header (3/3 signed) + 3: h3, + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + 4: newVals, + }, + false, + true, + }, } for _, tc := range testCases { @@ -368,7 +390,8 @@ func TestClient_Cleanup(t *testing.T) { require.NoError(t, err) c.Stop() - c.Cleanup() + err = c.Cleanup() + require.NoError(t, err) // Check no headers exist after Cleanup. h, err := c.TrustedHeader(1, bTime.Add(1*time.Second)) @@ -747,9 +770,6 @@ func TestClient_BackwardsVerification(t *testing.T) { Logger(log.TestingLogger()), ) require.NoError(t, err) - err = c.Start() - require.NoError(t, err) - defer c.Stop() // 1) header is missing => expect no error h, err := c.VerifyHeaderAtHeight(2, bTime.Add(1*time.Hour).Add(1*time.Second)) @@ -762,6 +782,15 @@ func TestClient_BackwardsVerification(t *testing.T) { h, err = c.VerifyHeaderAtHeight(1, bTime.Add(1*time.Hour).Add(1*time.Second)) assert.Error(t, err) assert.NotNil(t, h) + + // 3) already stored headers should return the header without error + h, err = c.VerifyHeaderAtHeight(3, bTime.Add(1*time.Hour).Add(1*time.Second)) + assert.NoError(t, err) + assert.NotNil(t, h) + + // 4) cannot verify a header in the future + _, err = c.VerifyHeaderAtHeight(4, bTime.Add(1*time.Hour).Add(1*time.Second)) + assert.Error(t, err) } func TestClient_NewClientFromTrustedStore(t *testing.T) {