diff --git a/light/client.go b/light/client.go index f741f1bee..5c5b206ce 100644 --- a/light/client.go +++ b/light/client.go @@ -630,7 +630,7 @@ func (c *Client) sequence( } else { // intermediate headers interimHeader, interimVals, err = c.signedHeaderAndValSetFromPrimary(height) if err != nil { - return err + return ErrVerificationFailed{From: trustedHeader.Height, To: height, Reason: err} } } @@ -644,11 +644,18 @@ func (c *Client) sequence( err = VerifyAdjacent(c.chainID, trustedHeader, interimHeader, interimVals, c.trustingPeriod, now, c.maxClockDrift) if err != nil { - err := fmt.Errorf("verify adjacent from #%d to #%d failed: %w", - trustedHeader.Height, interimHeader.Height, err) + err := ErrVerificationFailed{From: trustedHeader.Height, To: interimHeader.Height, Reason: err} switch errors.Unwrap(err).(type) { case ErrInvalidHeader: + // If the target header is invalid, return immediately. + if err.To == newHeader.Height { + c.logger.Debug("Target header is invalid", "err", err) + return err + } + + // If some intermediate header is invalid, replace the primary and try + // again. c.logger.Error("primary sent invalid header -> replacing", "err", err) replaceErr := c.replacePrimaryProvider() if replaceErr != nil { @@ -656,8 +663,28 @@ func (c *Client) sequence( // return original error return err } + + replacementHeader, replacementVals, fErr := c.signedHeaderAndValSetFromPrimary(newHeader.Height) + if fErr != nil { + c.logger.Error("Can't fetch header/vals from primary", "err", fErr) + // return original error + return err + } + + if !bytes.Equal(replacementHeader.Hash(), newHeader.Hash()) || + !bytes.Equal(replacementVals.Hash(), newVals.Hash()) { + c.logger.Error("Replacement provider has a different header/vals", + "newHash", newHeader.Hash(), + "newVals", newVals.Hash(), + "replHash", replacementHeader.Hash(), + "replVals", replacementVals.Hash()) + // return original error + return err + } + // attempt to verify header again height-- + continue default: return err @@ -728,15 +755,14 @@ func (c *Client) bisection( Height)*bisectionNumerator/bisectionDenominator interimHeader, interimVals, err := c.signedHeaderAndValSetFrom(pivotHeight, source) if err != nil { - return err + return ErrVerificationFailed{From: trustedHeader.Height, To: pivotHeight, Reason: err} } headerCache = append(headerCache, headerSet{interimHeader, interimVals}) } depth++ default: - return fmt.Errorf("verify non adjacent from #%d to #%d failed: %w", - trustedHeader.Height, headerCache[depth].sh.Height, err) + return ErrVerificationFailed{From: trustedHeader.Height, To: headerCache[depth].sh.Height, Reason: err} } } } @@ -755,6 +781,15 @@ func (c *Client) bisectionAgainstPrimary( switch errors.Unwrap(err).(type) { case ErrInvalidHeader: + // If the target header is invalid, return immediately. + invalidHeaderHeight := err.(ErrVerificationFailed).To + if invalidHeaderHeight == newHeader.Height { + c.logger.Debug("Target header is invalid", "err", err) + return err + } + + // If some intermediate header is invalid, replace the primary and try + // again. c.logger.Error("primary sent invalid header -> replacing", "err", err) replaceErr := c.replacePrimaryProvider() if replaceErr != nil { diff --git a/light/client_test.go b/light/client_test.go index a97281a6e..5e0c520bf 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -982,7 +982,7 @@ func TestClientRemovesWitnessIfItSendsUsIncorrectHeader(t *testing.T) { assert.EqualValues(t, 1, len(c.Witnesses())) } -func TestClientTrustedValidatorSet(t *testing.T) { +func TestClient_TrustedValidatorSet(t *testing.T) { noValSetNode := mockp.New( chainID, headerSet, @@ -997,7 +997,15 @@ func TestClientTrustedValidatorSet(t *testing.T) { badValSetNode := mockp.New( chainID, - headerSet, + map[int64]*types.SignedHeader{ + 1: h1, + // 3/3 signed, but validator set at height 2 below is invalid -> witness + // should be removed. + 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash"), hash("results_hash"), + 0, len(keys), types.BlockID{Hash: h1.Hash()}), + 3: h3, + }, map[int64]*types.ValidatorSet{ 1: vals, 2: differentVals, @@ -1009,19 +1017,16 @@ func TestClientTrustedValidatorSet(t *testing.T) { chainID, trustOptions, noValSetNode, - []provider.Provider{badValSetNode, fullNode, fullNode}, + []provider.Provider{fullNode, badValSetNode, fullNode}, dbs.New(dbm.NewMemDB(), chainID), light.Logger(log.TestingLogger()), ) require.NoError(t, err) assert.Equal(t, 2, len(c.Witnesses())) - _, err = c.VerifyHeaderAtHeight(2, bTime.Add(2*time.Hour).Add(1*time.Second)) - assert.Error(t, err) - assert.Equal(t, 1, len(c.Witnesses())) - _, err = c.VerifyHeaderAtHeight(2, bTime.Add(2*time.Hour).Add(1*time.Second)) assert.NoError(t, err) + assert.Equal(t, 1, len(c.Witnesses())) valSet, height, err := c.TrustedValidatorSet(0) assert.NoError(t, err) diff --git a/light/errors.go b/light/errors.go index a592d1c60..3846b00b3 100644 --- a/light/errors.go +++ b/light/errors.go @@ -56,6 +56,25 @@ func (e ErrConflictingHeaders) Error() string { e.H2.Hash(), e.Witness) } +// ErrVerificationFailed means either sequential or skipping verification has +// failed to verify from header #1 to header #2 due to some reason. +type ErrVerificationFailed struct { + From int64 + To int64 + Reason error +} + +// Unwrap returns underlying reason. +func (e ErrVerificationFailed) Unwrap() error { + return e.Reason +} + +func (e ErrVerificationFailed) Error() string { + return fmt.Sprintf( + "verify from #%d to #%d failed: %v", + e.From, e.To, e.Reason) +} + // errNoWitnesses means that there are not enough witnesses connected to // continue running the light client. type errNoWitnesses struct{}