diff --git a/light/client_benchmark_test.go b/light/client_benchmark_test.go index 1097fa233..38b44d271 100644 --- a/light/client_benchmark_test.go +++ b/light/client_benchmark_test.go @@ -64,7 +64,7 @@ func BenchmarkSequence(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) @@ -101,7 +101,7 @@ func BenchmarkBisection(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) genesisBlock, _ := benchmarkFullNode.LightBlock(ctx, 1) @@ -137,7 +137,7 @@ func BenchmarkBackwards(b *testing.B) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1000, 100, 1, bTime) + headers, vals, _ := genLightBlocksWithKeys(b, chainID, 1000, 100, 1, bTime) benchmarkFullNode := newProviderBenchmarkImpl(headers, vals) trustedBlock, _ := benchmarkFullNode.LightBlock(ctx, 0) diff --git a/light/client_test.go b/light/client_test.go index 5fc2c3a12..09abe6d9e 100644 --- a/light/client_test.go +++ b/light/client_test.go @@ -26,1092 +26,1092 @@ const ( chainID = "test" ) -var ( - keys = genPrivKeys(4) - vals = keys.ToValidators(20, 10) - bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - h1 = keys.GenSignedHeader(chainID, 1, bTime, nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) - // 3/3 signed - h2 = keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h1.Hash()}) - // 3/3 signed - h3 = keys.GenSignedHeaderLastBlockID(chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h2.Hash()}) - trustPeriod = 4 * time.Hour - trustOptions = light.TrustOptions{ - Period: 4 * time.Hour, - Height: 1, - Hash: h1.Hash(), - } - valSet = map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: vals, - 4: vals, - } - headerSet = map[int64]*types.SignedHeader{ - 1: h1, - // interim header (3/3 signed) - 2: h2, - // last header (3/3 signed) - 3: h3, +var bTime time.Time + +func init() { + var err error + bTime, err = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") + if err != nil { + panic(err) } - l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} - l2 = &types.LightBlock{SignedHeader: h2, ValidatorSet: vals} - l3 = &types.LightBlock{SignedHeader: h3, ValidatorSet: vals} -) +} -func TestValidateTrustOptions(t *testing.T) { - testCases := []struct { - err bool - to light.TrustOptions - }{ - { - false, - trustOptions, - }, - { - true, - light.TrustOptions{ - Period: -1 * time.Hour, - Height: 1, - Hash: h1.Hash(), +func TestClient(t *testing.T) { + var ( + keys = genPrivKeys(4) + vals = keys.ToValidators(20, 10) + trustPeriod = 4 * time.Hour + + valSet = map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: vals, + 4: vals, + } + + h1 = keys.GenSignedHeader(t, chainID, 1, bTime, nil, vals, vals, + hash("app_hash"), hash("cons_hash"), 
hash("results_hash"), 0, len(keys)) + // 3/3 signed + h2 = keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h1.Hash()}) + // 3/3 signed + h3 = keys.GenSignedHeaderLastBlockID(t, chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: h2.Hash()}) + trustOptions = light.TrustOptions{ + Period: 4 * time.Hour, + Height: 1, + Hash: h1.Hash(), + } + headerSet = map[int64]*types.SignedHeader{ + 1: h1, + // interim header (3/3 signed) + 2: h2, + // last header (3/3 signed) + 3: h3, + } + l1 = &types.LightBlock{SignedHeader: h1, ValidatorSet: vals} + l2 = &types.LightBlock{SignedHeader: h2, ValidatorSet: vals} + l3 = &types.LightBlock{SignedHeader: h3, ValidatorSet: vals} + ) + t.Run("ValidateTrustOptions", func(t *testing.T) { + testCases := []struct { + err bool + to light.TrustOptions + }{ + { + false, + trustOptions, }, - }, - { - true, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 0, - Hash: h1.Hash(), + { + true, + light.TrustOptions{ + Period: -1 * time.Hour, + Height: 1, + Hash: h1.Hash(), + }, }, - }, - { - true, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 1, - Hash: []byte("incorrect hash"), + { + true, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 0, + Hash: h1.Hash(), + }, + }, + { + true, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 1, + Hash: []byte("incorrect hash"), + }, }, - }, - } - - for _, tc := range testCases { - err := tc.to.ValidateBasic() - if tc.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) } - } -} + for idx, tc := range testCases { + t.Run(fmt.Sprint(idx), func(t *testing.T) { + err := tc.to.ValidateBasic() + if tc.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } + }) + t.Run("SequentialVerification", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func TestClient_SequentialVerification(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - newKeys := genPrivKeys(4) - newVals := newKeys.ToValidators(10, 1) - differentVals, _ := factory.RandValidatorSet(ctx, t, 10, 100) - - testCases := []struct { - name string - otherHeaders map[int64]*types.SignedHeader // all except ^ - vals map[int64]*types.ValidatorSet - initErr bool - verifyErr bool - }{ - { - "good", - headerSet, - valSet, - false, - false, - }, - { - "bad: different first header", - map[int64]*types.SignedHeader{ - // different header - 1: keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + newKeys := genPrivKeys(4) + newVals := newKeys.ToValidators(10, 1) + differentVals, _ := factory.RandValidatorSet(ctx, t, 10, 100) + + testCases := []struct { + name string + otherHeaders map[int64]*types.SignedHeader // all except ^ + vals map[int64]*types.ValidatorSet + initErr bool + verifyErr bool + }{ + { + name: "good", + otherHeaders: headerSet, + vals: valSet, + initErr: false, + verifyErr: false, }, - map[int64]*types.ValidatorSet{ - 1: vals, + { + "bad: different first header", + map[int64]*types.SignedHeader{ + // different header + 1: keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + }, + 
map[int64]*types.ValidatorSet{ + 1: vals, + }, + true, + false, }, - true, - false, - }, - { - "bad: no first signed header", - map[int64]*types.SignedHeader{}, - map[int64]*types.ValidatorSet{ - 1: differentVals, + { + "bad: no first signed header", + map[int64]*types.SignedHeader{}, + map[int64]*types.ValidatorSet{ + 1: differentVals, + }, + true, + true, }, - true, - true, - }, - { - "bad: different first validator set", - map[int64]*types.SignedHeader{ - 1: h1, + { + "bad: different first validator set", + map[int64]*types.SignedHeader{ + 1: h1, + }, + map[int64]*types.ValidatorSet{ + 1: differentVals, + }, + true, + true, }, - map[int64]*types.ValidatorSet{ - 1: differentVals, + { + "bad: 1/3 signed interim header", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (1/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), + // last header (3/3 signed) + 3: keys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + }, + valSet, + false, + true, }, - true, - true, - }, - { - "bad: 1/3 signed interim header", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (1/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), - // last header (3/3 signed) - 3: keys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + { + "bad: 1/3 signed last header", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (3/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + // last header (1/3 signed) + 3: keys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), + }, + valSet, + false, + true, }, - valSet, - false, - true, - }, - { - "bad: 1/3 signed last header", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (3/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - // last header (1/3 signed) - 3: keys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), + { + "bad: different validator set at height 3", + headerSet, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + true, }, - valSet, - false, - true, - }, - { - "bad: different validator set at height 3", - headerSet, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, + } + + for _, tc := range testCases { + testCase := tc + t.Run(testCase.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + logger := log.NewTestingLogger(t) + + mockNode := mockNodeFromHeadersAndVals(testCase.otherHeaders, testCase.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockNode, + 
[]provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.SequentialVerification(), + light.Logger(logger), + ) + + if testCase.initErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) + if testCase.verifyErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + mockNode.AssertExpectations(t) + }) + } + + }) + t.Run("SkippingVerification", func(t *testing.T) { + // required for 2nd test case + newKeys := genPrivKeys(4) + newVals := newKeys.ToValidators(10, 1) + + // 1/3+ of vals, 2/3- of newVals + transitKeys := keys.Extend(3) + transitVals := transitKeys.ToValidators(10, 1) + + testCases := []struct { + name string + otherHeaders map[int64]*types.SignedHeader // all except ^ + vals map[int64]*types.ValidatorSet + initErr bool + verifyErr bool + }{ + { + "good", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // last header (3/3 signed) + 3: h3, + }, + valSet, + false, + false, }, - false, - true, - }, - } + { + "good, but val set changes by 2/3 (1/3 of vals is still present)", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + 3: transitKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, transitVals, transitVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(transitKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: transitVals, + }, + false, + false, + }, + { + "good, but val set changes 100% at height 2", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // interim header (3/3 signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), + // last header (0/4 of the original val set signed) + 3: newKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + false, + }, + { + "bad: last header signed by newVals, interim header has no signers", + map[int64]*types.SignedHeader{ + // trusted header + 1: h1, + // last header (0/4 of the original val set signed) + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, 0), + // last header (0/4 of the original val set signed) + 3: newKeys.GenSignedHeader(t, chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), + }, + map[int64]*types.ValidatorSet{ + 1: vals, + 2: vals, + 3: newVals, + }, + false, + true, + }, + } + + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) + defer cancel() + logger := log.NewTestingLogger(t) + + mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) + mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), + light.Logger(logger), + ) + if tc.initErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + + _, err = 
c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) + if tc.verifyErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } + + }) + t.Run("LargeBisectionVerification", func(t *testing.T) { + // start from a large light block to make sure that the pivot height doesn't select a height outside + // the appropriate range + + numBlocks := int64(300) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(t, chainID, numBlocks, 101, 2, bTime) + + lastBlock := &types.LightBlock{SignedHeader: mockHeaders[numBlocks], ValidatorSet: mockVals[numBlocks]} + mockNode := &provider_mocks.Provider{} + mockNode.On("LightBlock", mock.Anything, numBlocks). + Return(lastBlock, nil) + + mockNode.On("LightBlock", mock.Anything, int64(200)). + Return(&types.LightBlock{SignedHeader: mockHeaders[200], ValidatorSet: mockVals[200]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(256)). + Return(&types.LightBlock{SignedHeader: mockHeaders[256], ValidatorSet: mockVals[256]}, nil) + + mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) + require.NoError(t, err) + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: trustedLightBlock.Height, + Hash: trustedLightBlock.Hash(), + }, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), + ) + require.NoError(t, err) + h, err := c.Update(ctx, bTime.Add(300*time.Minute)) + assert.NoError(t, err) + height, err := c.LastTrustedHeight() + require.NoError(t, err) + require.Equal(t, numBlocks, height) + h2, err := mockNode.LightBlock(ctx, numBlocks) + require.NoError(t, err) + assert.Equal(t, h, h2) + mockNode.AssertExpectations(t) + }) + t.Run("BisectionBetweenTrustedHeaders", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: 1, + Hash: h1.Hash(), + }, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.SkippingVerification(light.DefaultTrustLevel), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + require.NoError(t, err) + + // confirm that the client already doesn't have the light block + _, err = c.TrustedLightBlock(2) + require.Error(t, err) + + // verify using bisection the light block between the two trusted light blocks + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) + assert.NoError(t, err) + mockFullNode.AssertExpectations(t) + }) + t.Run("Cleanup", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) + + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + _, err = c.TrustedLightBlock(1) + require.NoError(t, err) + + err = c.Cleanup() + require.NoError(t, err) + + // Check no light blocks exist after Cleanup. 
+ l, err := c.TrustedLightBlock(1) + assert.Error(t, err) + assert.Nil(t, l) + mockFullNode.AssertExpectations(t) + }) + t.Run("RestoresTrustedHeaderAfterStartup", func(t *testing.T) { + // trustedHeader.Height == options.Height + + bctx, bcancel := context.WithCancel(context.Background()) + defer bcancel() - for _, tc := range testCases { - testCase := tc - t.Run(testCase.name, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + // 1. options.Hash == trustedHeader.Hash + t.Run("hashes should match", func(t *testing.T) { + ctx, cancel := context.WithCancel(bctx) defer cancel() logger := log.NewTestingLogger(t) - mockNode := mockNodeFromHeadersAndVals(testCase.otherHeaders, testCase.vals) - mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + mockNode := &provider_mocks.Provider{} + trustedStore := dbs.New(dbm.NewMemDB()) + err := trustedStore.SaveLightBlock(l1) + require.NoError(t, err) + c, err := light.NewClient( ctx, chainID, trustOptions, mockNode, []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SequentialVerification(), + trustedStore, light.Logger(logger), ) - - if testCase.initErr { - require.Error(t, err) - return - } - require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) - if testCase.verifyErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.NotNil(t, l) + assert.Equal(t, l.Hash(), h1.Hash()) + assert.Equal(t, l.ValidatorSet.Hash(), h1.ValidatorsHash.Bytes()) mockNode.AssertExpectations(t) }) - } -} - -func TestClient_SkippingVerification(t *testing.T) { - // required for 2nd test case - newKeys := genPrivKeys(4) - newVals := newKeys.ToValidators(10, 1) - - // 1/3+ of vals, 2/3- of newVals - transitKeys := keys.Extend(3) - transitVals := transitKeys.ToValidators(10, 1) - - testCases := []struct { - name string - otherHeaders map[int64]*types.SignedHeader // all except ^ - vals map[int64]*types.ValidatorSet - initErr bool - verifyErr bool - }{ - { - "good", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // last header (3/3 signed) - 3: h3, - }, - valSet, - false, - false, - }, - { - "good, but val set changes by 2/3 (1/3 of vals is still present)", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - 3: transitKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, transitVals, transitVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(transitKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: transitVals, - }, - false, - false, - }, - { - "good, but val set changes 100% at height 2", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // interim header (3/3 signed) - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(1*time.Hour), nil, vals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), - // last header (0/4 of the original val set signed) - 3: newKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, - }, - false, - false, - }, - { - "bad: last header signed by newVals, interim header has no signers", - map[int64]*types.SignedHeader{ - // trusted header - 1: h1, - // last header (0/4 of the original val set signed) - 2: keys.GenSignedHeader(chainID, 2, 
bTime.Add(1*time.Hour), nil, vals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, 0), - // last header (0/4 of the original val set signed) - 3: newKeys.GenSignedHeader(chainID, 3, bTime.Add(2*time.Hour), nil, newVals, newVals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(newKeys)), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - 3: newVals, - }, - false, - true, - }, - } - - bctx, bcancel := context.WithCancel(context.Background()) - defer bcancel() - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { + // 2. options.Hash != trustedHeader.Hash + t.Run("hashes should not match", func(t *testing.T) { ctx, cancel := context.WithCancel(bctx) defer cancel() + + trustedStore := dbs.New(dbm.NewMemDB()) + err := trustedStore.SaveLightBlock(l1) + require.NoError(t, err) + logger := log.NewTestingLogger(t) - mockNode := mockNodeFromHeadersAndVals(tc.otherHeaders, tc.vals) - mockNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + // header1 != h1 + header1 := keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, + hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) + mockNode := &provider_mocks.Provider{} + c, err := light.NewClient( ctx, chainID, - trustOptions, + light.TrustOptions{ + Period: 4 * time.Hour, + Height: 1, + Hash: header1.Hash(), + }, mockNode, []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), + trustedStore, light.Logger(logger), ) - if tc.initErr { - require.Error(t, err) - return - } - require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(3*time.Hour)) - if tc.verifyErr { - assert.Error(t, err) - } else { - assert.NoError(t, err) + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + if assert.NotNil(t, l) { + // client take the trusted store and ignores the trusted options + assert.Equal(t, l.Hash(), l1.Hash()) + assert.NoError(t, l.ValidateBasic(chainID)) } + mockNode.AssertExpectations(t) }) - } - -} - -// start from a large light block to make sure that the pivot height doesn't select a height outside -// the appropriate range -func TestClientLargeBisectionVerification(t *testing.T) { - numBlocks := int64(300) - mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, numBlocks, 101, 2, bTime) - - lastBlock := &types.LightBlock{SignedHeader: mockHeaders[numBlocks], ValidatorSet: mockVals[numBlocks]} - mockNode := &provider_mocks.Provider{} - mockNode.On("LightBlock", mock.Anything, numBlocks). - Return(lastBlock, nil) - - mockNode.On("LightBlock", mock.Anything, int64(200)). - Return(&types.LightBlock{SignedHeader: mockHeaders[200], ValidatorSet: mockVals[200]}, nil) - - mockNode.On("LightBlock", mock.Anything, int64(256)). 
- Return(&types.LightBlock{SignedHeader: mockHeaders[256], ValidatorSet: mockVals[256]}, nil) - - mockNode.On("LightBlock", mock.Anything, int64(0)).Return(lastBlock, nil) - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - trustedLightBlock, err := mockNode.LightBlock(ctx, int64(200)) - require.NoError(t, err) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ - Period: 4 * time.Hour, - Height: trustedLightBlock.Height, - Hash: trustedLightBlock.Hash(), - }, - mockNode, - []provider.Provider{mockNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), - ) - require.NoError(t, err) - h, err := c.Update(ctx, bTime.Add(300*time.Minute)) - assert.NoError(t, err) - height, err := c.LastTrustedHeight() - require.NoError(t, err) - require.Equal(t, numBlocks, height) - h2, err := mockNode.LightBlock(ctx, numBlocks) - require.NoError(t, err) - assert.Equal(t, h, h2) - mockNode.AssertExpectations(t) -} - -func TestClientBisectionBetweenTrustedHeaders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) - c, err := light.NewClient( - ctx, - chainID, - light.TrustOptions{ - Period: 4 * time.Hour, - Height: 1, - Hash: h1.Hash(), - }, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.SkippingVerification(light.DefaultTrustLevel), - ) - require.NoError(t, err) - - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - // confirm that the client already doesn't have the light block - _, err = c.TrustedLightBlock(2) - require.Error(t, err) - - // verify using bisection the light block between the two trusted light blocks - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour)) - assert.NoError(t, err) - mockFullNode.AssertExpectations(t) -} - -func TestClient_Cleanup(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - _, err = c.TrustedLightBlock(1) - require.NoError(t, err) - - err = c.Cleanup() - require.NoError(t, err) - - // Check no light blocks exist after Cleanup. - l, err := c.TrustedLightBlock(1) - assert.Error(t, err) - assert.Nil(t, l) - mockFullNode.AssertExpectations(t) -} - -// trustedHeader.Height == options.Height -func TestClientRestoresTrustedHeaderAfterStartup(t *testing.T) { - bctx, bcancel := context.WithCancel(context.Background()) - defer bcancel() - - // 1. 
options.Hash == trustedHeader.Hash - t.Run("hashes should match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) + }) + t.Run("Update", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - logger := log.NewTestingLogger(t) + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(3)).Return(l3, nil) - mockNode := &provider_mocks.Provider{} - trustedStore := dbs.New(dbm.NewMemDB()) - err := trustedStore.SaveLightBlock(l1) - require.NoError(t, err) + logger := log.NewTestingLogger(t) c, err := light.NewClient( ctx, chainID, trustOptions, - mockNode, - []provider.Provider{mockNode}, - trustedStore, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), light.Logger(logger), ) require.NoError(t, err) - l, err := c.TrustedLightBlock(1) + // should result in downloading & verifying header #3 + l, err := c.Update(ctx, bTime.Add(2*time.Hour)) assert.NoError(t, err) - assert.NotNil(t, l) - assert.Equal(t, l.Hash(), h1.Hash()) - assert.Equal(t, l.ValidatorSet.Hash(), h1.ValidatorsHash.Bytes()) - mockNode.AssertExpectations(t) + if assert.NotNil(t, l) { + assert.EqualValues(t, 3, l.Height) + assert.NoError(t, l.ValidateBasic(chainID)) + } + mockFullNode.AssertExpectations(t) }) - // 2. options.Hash != trustedHeader.Hash - t.Run("hashes should not match", func(t *testing.T) { - ctx, cancel := context.WithCancel(bctx) + t.Run("Concurrency", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) defer cancel() - - trustedStore := dbs.New(dbm.NewMemDB()) - err := trustedStore.SaveLightBlock(l1) - require.NoError(t, err) - logger := log.NewTestingLogger(t) - // header1 != h1 - header1 := keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Hour), nil, vals, vals, - hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) - mockNode := &provider_mocks.Provider{} - + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, int64(2)).Return(l2, nil) + mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) c, err := light.NewClient( ctx, chainID, - light.TrustOptions{ - Period: 4 * time.Hour, - Height: 1, - Hash: header1.Hash(), - }, - mockNode, - []provider.Provider{mockNode}, - trustedStore, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), light.Logger(logger), ) require.NoError(t, err) - l, err := c.TrustedLightBlock(1) - assert.NoError(t, err) - if assert.NotNil(t, l) { - // client take the trusted store and ignores the trusted options - assert.Equal(t, l.Hash(), l1.Hash()) - assert.NoError(t, l.ValidateBasic(chainID)) - } - mockNode.AssertExpectations(t) - }) -} - -func TestClient_Update(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(0)).Return(l3, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(3)).Return(l3, nil) - - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - - // should result in downloading & 
verifying header #3 - l, err := c.Update(ctx, bTime.Add(2*time.Hour)) - assert.NoError(t, err) - if assert.NotNil(t, l) { - assert.EqualValues(t, 3, l.Height) - assert.NoError(t, l.ValidateBasic(chainID)) - } - mockFullNode.AssertExpectations(t) -} - -func TestClient_Concurrency(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, int64(2)).Return(l2, nil) - mockFullNode.On("LightBlock", mock.Anything, int64(1)).Return(l1, nil) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) - require.NoError(t, err) + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + require.NoError(t, err) - var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) - go func() { - defer wg.Done() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() - // NOTE: Cleanup, Stop, VerifyLightBlockAtHeight and Verify are not supposed - // to be concurrently safe. + // NOTE: Cleanup, Stop, VerifyLightBlockAtHeight and Verify are not supposed + // to be concurrently safe. - assert.Equal(t, chainID, c.ChainID()) + assert.Equal(t, chainID, c.ChainID()) - _, err := c.LastTrustedHeight() - assert.NoError(t, err) + _, err := c.LastTrustedHeight() + assert.NoError(t, err) - _, err = c.FirstTrustedHeight() - assert.NoError(t, err) + _, err = c.FirstTrustedHeight() + assert.NoError(t, err) - l, err := c.TrustedLightBlock(1) - assert.NoError(t, err) - assert.NotNil(t, l) - }() - } + l, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.NotNil(t, l) + }() + } - wg.Wait() - mockFullNode.AssertExpectations(t) -} + wg.Wait() + mockFullNode.AssertExpectations(t) + }) + t.Run("AddProviders", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func TestClient_AddProviders(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - }, valSet) - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) + mockFullNode := mockNodeFromHeadersAndVals(map[int64]*types.SignedHeader{ + 1: h1, + 2: h2, + }, valSet) + logger := log.NewTestingLogger(t) - closeCh := make(chan struct{}) - go func() { - // run verification concurrently to make sure it doesn't dead lock - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) require.NoError(t, err) - close(closeCh) - }() - - // NOTE: the light client doesn't check uniqueness of providers - c.AddProvider(mockFullNode) - require.Len(t, c.Witnesses(), 2) - select { - case <-closeCh: - case <-time.After(5 * time.Second): - t.Fatal("concurent light block verification failed to finish in 5s") - } - mockFullNode.AssertExpectations(t) -} -func TestClientReplacesPrimaryWithWitnessIfPrimaryIsUnavailable(t *testing.T) { - ctx, cancel := 
context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) - - mockDeadNode := &provider_mocks.Provider{} - mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) - - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockDeadNode, - []provider.Provider{mockDeadNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - - require.NoError(t, err) - _, err = c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - // the primary should no longer be the deadNode - assert.NotEqual(t, c.Primary(), mockDeadNode) + closeCh := make(chan struct{}) + go func() { + // run verification concurrently to make sure it doesn't dead lock + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + require.NoError(t, err) + close(closeCh) + }() - // we should still have the dead node as a witness because it - // hasn't repeatedly been unresponsive yet - assert.Equal(t, 2, len(c.Witnesses())) - mockDeadNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} + // NOTE: the light client doesn't check uniqueness of providers + c.AddProvider(mockFullNode) + require.Len(t, c.Witnesses(), 2) + select { + case <-closeCh: + case <-time.After(5 * time.Second): + t.Fatal("concurent light block verification failed to finish in 5s") + } + mockFullNode.AssertExpectations(t) + }) + t.Run("ReplacesPrimaryWithWitnessIfPrimaryIsUnavailable", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() -func TestClientReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - mockFullNode := &provider_mocks.Provider{} - mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) - - logger := log.NewTestingLogger(t) - - mockDeadNode := &provider_mocks.Provider{} - mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockDeadNode, - []provider.Provider{mockDeadNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - _, err = c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - - // we should still have the dead node as a witness because it - // hasn't repeatedly been unresponsive yet - assert.Equal(t, 2, len(c.Witnesses())) - mockDeadNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) -func TestClient_BackwardsVerification(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) - { - headers, vals, _ := genLightBlocksWithKeys(chainID, 9, 3, 0, bTime) - delete(headers, 1) - delete(headers, 2) - delete(vals, 1) - delete(vals, 2) - mockLargeFullNode := mockNodeFromHeadersAndVals(headers, vals) - trustHeader, _ := mockLargeFullNode.LightBlock(ctx, 6) + logger := log.NewTestingLogger(t) c, err := light.NewClient( ctx, chainID, - light.TrustOptions{ - Period: 4 * time.Minute, - Height: trustHeader.Height, - Hash: 
trustHeader.Hash(), - }, - mockLargeFullNode, - []provider.Provider{mockLargeFullNode}, + trustOptions, + mockDeadNode, + []provider.Provider{mockDeadNode, mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(logger), ) - require.NoError(t, err) - // 1) verify before the trusted header using backwards => expect no error - h, err := c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) require.NoError(t, err) - if assert.NotNil(t, h) { - assert.EqualValues(t, 5, h.Height) - } - - // 2) untrusted header is expired but trusted header is not => expect no error - h, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(8*time.Minute)) - assert.NoError(t, err) - assert.NotNil(t, h) + _, err = c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) - // 3) already stored headers should return the header without error - h, err = c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) - assert.NoError(t, err) - assert.NotNil(t, h) + // the primary should no longer be the deadNode + assert.NotEqual(t, c.Primary(), mockDeadNode) - // 4a) First verify latest header - _, err = c.VerifyLightBlockAtHeight(ctx, 9, bTime.Add(9*time.Minute)) - require.NoError(t, err) + // we should still have the dead node as a witness because it + // hasn't repeatedly been unresponsive yet + assert.Equal(t, 2, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("ReplacesPrimaryWithWitnessIfPrimaryDoesntHaveBlock", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - // 4b) Verify backwards using bisection => expect no error - _, err = c.VerifyLightBlockAtHeight(ctx, 7, bTime.Add(9*time.Minute)) - assert.NoError(t, err) - // shouldn't have verified this header in the process - _, err = c.TrustedLightBlock(8) - assert.Error(t, err) + mockFullNode := &provider_mocks.Provider{} + mockFullNode.On("LightBlock", mock.Anything, mock.Anything).Return(l1, nil) - // 5) Try bisection method, but closest header (at 7) has expired - // so expect error - _, err = c.VerifyLightBlockAtHeight(ctx, 8, bTime.Add(12*time.Minute)) - assert.Error(t, err) - mockLargeFullNode.AssertExpectations(t) + logger := log.NewTestingLogger(t) - } - { - // 8) provides incorrect hash - headers := map[int64]*types.SignedHeader{ - 2: keys.GenSignedHeader(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), - 3: h3, - } - vals := valSet - mockNode := mockNodeFromHeadersAndVals(headers, vals) + mockDeadNode := &provider_mocks.Provider{} + mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) c, err := light.NewClient( ctx, chainID, - light.TrustOptions{ - Period: 1 * time.Hour, - Height: 3, - Hash: h3.Hash(), - }, - mockNode, - []provider.Provider{mockNode}, + trustOptions, + mockDeadNode, + []provider.Provider{mockDeadNode, mockFullNode}, dbs.New(dbm.NewMemDB()), light.Logger(logger), ) require.NoError(t, err) + _, err = c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour).Add(1*time.Second)) - assert.Error(t, err) - mockNode.AssertExpectations(t) - } -} + // we should still have the dead node as a witness because it + // hasn't repeatedly been unresponsive yet + assert.Equal(t, 2, len(c.Witnesses())) + mockDeadNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("BackwardsVerification", func(t *testing.T) { + 
ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) -func TestClient_NewClientFromTrustedStore(t *testing.T) { - // 1) Initiate DB and fill with a "trusted" header - db := dbs.New(dbm.NewMemDB()) - err := db.SaveLightBlock(l1) - require.NoError(t, err) - mockNode := &provider_mocks.Provider{} - - c, err := light.NewClientFromTrustedStore( - chainID, - trustPeriod, - mockNode, - []provider.Provider{mockNode}, - db, - ) - require.NoError(t, err) + { + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 9, 3, 0, bTime) + delete(headers, 1) + delete(headers, 2) + delete(vals, 1) + delete(vals, 2) + mockLargeFullNode := mockNodeFromHeadersAndVals(headers, vals) + trustHeader, _ := mockLargeFullNode.LightBlock(ctx, 6) - // 2) Check light block exists - h, err := c.TrustedLightBlock(1) - assert.NoError(t, err) - assert.EqualValues(t, l1.Height, h.Height) - mockNode.AssertExpectations(t) -} + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 4 * time.Minute, + Height: trustHeader.Height, + Hash: trustHeader.Hash(), + }, + mockLargeFullNode, + []provider.Provider{mockLargeFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) -func TestClientRemovesWitnessIfItSendsUsIncorrectHeader(t *testing.T) { - logger := log.NewTestingLogger(t) + // 1) verify before the trusted header using backwards => expect no error + h, err := c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) + require.NoError(t, err) + if assert.NotNil(t, h) { + assert.EqualValues(t, 5, h.Height) + } - // different headers hash then primary plus less than 1/3 signed (no fork) - headers1 := map[int64]*types.SignedHeader{ - 1: h1, - 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, - hash("app_hash2"), hash("cons_hash"), hash("results_hash"), - len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), - } - vals1 := map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - } - mockBadNode1 := mockNodeFromHeadersAndVals(headers1, vals1) - mockBadNode1.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + // 2) untrusted header is expired but trusted header is not => expect no error + h, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(8*time.Minute)) + assert.NoError(t, err) + assert.NotNil(t, h) - // header is empty - headers2 := map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - } - vals2 := map[int64]*types.ValidatorSet{ - 1: vals, - 2: vals, - } - mockBadNode2 := mockNodeFromHeadersAndVals(headers2, vals2) - mockBadNode2.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) + // 3) already stored headers should return the header without error + h, err = c.VerifyLightBlockAtHeight(ctx, 5, bTime.Add(6*time.Minute)) + assert.NoError(t, err) + assert.NotNil(t, h) - mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) + // 4a) First verify latest header + _, err = c.VerifyLightBlockAtHeight(ctx, 9, bTime.Add(9*time.Minute)) + require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + // 4b) Verify backwards using bisection => expect no error + _, err = c.VerifyLightBlockAtHeight(ctx, 7, bTime.Add(9*time.Minute)) + assert.NoError(t, err) + // shouldn't have verified this header in the process + _, err = c.TrustedLightBlock(8) + assert.Error(t, err) - lb1, _ := mockBadNode1.LightBlock(ctx, 2) - require.NotEqual(t, lb1.Hash(), l1.Hash()) + // 5) Try 
bisection method, but closest header (at 7) has expired + // so expect error + _, err = c.VerifyLightBlockAtHeight(ctx, 8, bTime.Add(12*time.Minute)) + assert.Error(t, err) + mockLargeFullNode.AssertExpectations(t) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockBadNode1, mockBadNode2}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - // witness should have behaved properly -> no error - require.NoError(t, err) - assert.EqualValues(t, 2, len(c.Witnesses())) - - // witness behaves incorrectly -> removed from list, no error - l, err := c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) - assert.NoError(t, err) - assert.EqualValues(t, 1, len(c.Witnesses())) - // light block should still be verified - assert.EqualValues(t, 2, l.Height) - - // remaining witnesses don't have light block -> error - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - if assert.Error(t, err) { - assert.Equal(t, light.ErrFailedHeaderCrossReferencing, err) - } - // witness does not have a light block -> left in the list - assert.EqualValues(t, 1, len(c.Witnesses())) - mockBadNode1.AssertExpectations(t) - mockBadNode2.AssertExpectations(t) -} + } + { + // 8) provides incorrect hash + headers := map[int64]*types.SignedHeader{ + 2: keys.GenSignedHeader(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash23"), hash("results_hash30"), 0, len(keys)), + 3: h3, + } + vals := valSet + mockNode := mockNodeFromHeadersAndVals(headers, vals) + c, err := light.NewClient( + ctx, + chainID, + light.TrustOptions{ + Period: 1 * time.Hour, + Height: 3, + Hash: h3.Hash(), + }, + mockNode, + []provider.Provider{mockNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) -func TestClient_TrustedValidatorSet(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(1*time.Hour).Add(1*time.Second)) + assert.Error(t, err) + mockNode.AssertExpectations(t) + } + }) + t.Run("NewClientFromTrustedStore", func(t *testing.T) { + // 1) Initiate DB and fill with a "trusted" header + db := dbs.New(dbm.NewMemDB()) + err := db.SaveLightBlock(l1) + require.NoError(t, err) + mockNode := &provider_mocks.Provider{} + + c, err := light.NewClientFromTrustedStore( + chainID, + trustPeriod, + mockNode, + []provider.Provider{mockNode}, + db, + ) + require.NoError(t, err) - logger := log.NewTestingLogger(t) + // 2) Check light block exists + h, err := c.TrustedLightBlock(1) + assert.NoError(t, err) + assert.EqualValues(t, l1.Height, h.Height) + mockNode.AssertExpectations(t) + }) + t.Run("RemovesWitnessIfItSendsUsIncorrectHeader", func(t *testing.T) { + logger := log.NewTestingLogger(t) - differentVals, _ := factory.RandValidatorSet(ctx, t, 10, 100) - mockBadValSetNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ + // different headers hash then primary plus less than 1/3 signed (no fork) + headers1 := map[int64]*types.SignedHeader{ 1: h1, - // 3/3 signed, but validator set at height 2 below is invalid -> witness - // should be removed. 
- 2: keys.GenSignedHeaderLastBlockID(chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + 2: keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, hash("app_hash2"), hash("cons_hash"), hash("results_hash"), - 0, len(keys), types.BlockID{Hash: h1.Hash()}), - }, - map[int64]*types.ValidatorSet{ - 1: vals, - 2: differentVals, - }) - mockFullNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ - 1: h1, - 2: h2, - }, - map[int64]*types.ValidatorSet{ + len(keys), len(keys), types.BlockID{Hash: h1.Hash()}), + } + vals1 := map[int64]*types.ValidatorSet{ 1: vals, 2: vals, - }) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockBadValSetNode, mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - ) - require.NoError(t, err) - assert.Equal(t, 2, len(c.Witnesses())) - - _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour).Add(1*time.Second)) - assert.NoError(t, err) - assert.Equal(t, 1, len(c.Witnesses())) - mockBadValSetNode.AssertExpectations(t) - mockFullNode.AssertExpectations(t) -} + } + mockBadNode1 := mockNodeFromHeadersAndVals(headers1, vals1) + mockBadNode1.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) -func TestClientPrunesHeadersAndValidatorSets(t *testing.T) { - mockFullNode := mockNodeFromHeadersAndVals( - map[int64]*types.SignedHeader{ + // header is empty + headers2 := map[int64]*types.SignedHeader{ 1: h1, - 3: h3, - 0: h3, - }, - map[int64]*types.ValidatorSet{ + 2: h2, + } + vals2 := map[int64]*types.ValidatorSet{ 1: vals, - 3: vals, - 0: vals, - }) + 2: vals, + } + mockBadNode2 := mockNodeFromHeadersAndVals(headers2, vals2) + mockBadNode2.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrLightBlockNotFound) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - logger := log.NewTestingLogger(t) - - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - mockFullNode, - []provider.Provider{mockFullNode}, - dbs.New(dbm.NewMemDB()), - light.Logger(logger), - light.PruningSize(1), - ) - require.NoError(t, err) - _, err = c.TrustedLightBlock(1) - require.NoError(t, err) + mockFullNode := mockNodeFromHeadersAndVals(headerSet, valSet) - h, err := c.Update(ctx, bTime.Add(2*time.Hour)) - require.NoError(t, err) - require.Equal(t, int64(3), h.Height) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - _, err = c.TrustedLightBlock(1) - assert.Error(t, err) - mockFullNode.AssertExpectations(t) -} + lb1, _ := mockBadNode1.LightBlock(ctx, 2) + require.NotEqual(t, lb1.Hash(), l1.Hash()) -func TestClientEnsureValidHeadersAndValSets(t *testing.T) { - emptyValSet := &types.ValidatorSet{ - Validators: nil, - Proposer: nil, - } + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockBadNode1, mockBadNode2}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + // witness should have behaved properly -> no error + require.NoError(t, err) + assert.EqualValues(t, 2, len(c.Witnesses())) - testCases := []struct { - headers map[int64]*types.SignedHeader - vals map[int64]*types.ValidatorSet + // witness behaves incorrectly -> removed from list, no error + l, err := c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour)) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(c.Witnesses())) + // light block should still be verified + assert.EqualValues(t, 2, l.Height) + + // remaining witnesses don't 
have light block -> error + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + if assert.Error(t, err) { + assert.Equal(t, light.ErrFailedHeaderCrossReferencing, err) + } + // witness does not have a light block -> left in the list + assert.EqualValues(t, 1, len(c.Witnesses())) + mockBadNode1.AssertExpectations(t) + mockBadNode2.AssertExpectations(t) + }) + t.Run("TrustedValidatorSet", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - errorToThrow error - errorHeight int64 + logger := log.NewTestingLogger(t) - err bool - }{ - { - headers: map[int64]*types.SignedHeader{ + differentVals, _ := factory.RandValidatorSet(ctx, t, 10, 100) + mockBadValSetNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ 1: h1, - 3: h3, + // 3/3 signed, but validator set at height 2 below is invalid -> witness + // should be removed. + 2: keys.GenSignedHeaderLastBlockID(t, chainID, 2, bTime.Add(30*time.Minute), nil, vals, vals, + hash("app_hash2"), hash("cons_hash"), hash("results_hash"), + 0, len(keys), types.BlockID{Hash: h1.Hash()}), }, - vals: map[int64]*types.ValidatorSet{ + map[int64]*types.ValidatorSet{ 1: vals, - 3: vals, - }, - err: false, - }, - { - headers: map[int64]*types.SignedHeader{ + 2: differentVals, + }) + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ 1: h1, + 2: h2, }, - vals: map[int64]*types.ValidatorSet{ + map[int64]*types.ValidatorSet{ 1: vals, - }, - errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, - errorHeight: 3, - err: true, - }, - { - headers: map[int64]*types.SignedHeader{ - 1: h1, - }, - errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, - errorHeight: 3, - vals: valSet, - err: true, - }, - { - headers: map[int64]*types.SignedHeader{ + 2: vals, + }) + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockBadValSetNode, mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + ) + require.NoError(t, err) + assert.Equal(t, 2, len(c.Witnesses())) + + _, err = c.VerifyLightBlockAtHeight(ctx, 2, bTime.Add(2*time.Hour).Add(1*time.Second)) + assert.NoError(t, err) + assert.Equal(t, 1, len(c.Witnesses())) + mockBadValSetNode.AssertExpectations(t) + mockFullNode.AssertExpectations(t) + }) + t.Run("PrunesHeadersAndValidatorSets", func(t *testing.T) { + mockFullNode := mockNodeFromHeadersAndVals( + map[int64]*types.SignedHeader{ 1: h1, 3: h3, + 0: h3, }, - vals: map[int64]*types.ValidatorSet{ + map[int64]*types.ValidatorSet{ 1: vals, - 3: emptyValSet, - }, - err: true, - }, - } + 3: vals, + 0: vals, + }) - for i, tc := range testCases { - testCase := tc - t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + logger := log.NewTestingLogger(t) - mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) - if testCase.errorToThrow != nil { - mockBadNode.On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) - } + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockFullNode, + []provider.Provider{mockFullNode}, + dbs.New(dbm.NewMemDB()), + light.Logger(logger), + light.PruningSize(1), + ) + require.NoError(t, err) + _, err = c.TrustedLightBlock(1) + require.NoError(t, err) - c, err := light.NewClient( - ctx, - chainID, - trustOptions, - 
mockBadNode, - []provider.Provider{mockBadNode, mockBadNode}, - dbs.New(dbm.NewMemDB()), - ) - require.NoError(t, err) + h, err := c.Update(ctx, bTime.Add(2*time.Hour)) + require.NoError(t, err) + require.Equal(t, int64(3), h.Height) - _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) - if testCase.err { - assert.Error(t, err) - } else { - assert.NoError(t, err) - } - mockBadNode.AssertExpectations(t) - }) - } + _, err = c.TrustedLightBlock(1) + assert.Error(t, err) + mockFullNode.AssertExpectations(t) + }) + t.Run("EnsureValidHeadersAndValSets", func(t *testing.T) { + emptyValSet := &types.ValidatorSet{ + Validators: nil, + Proposer: nil, + } + + testCases := []struct { + headers map[int64]*types.SignedHeader + vals map[int64]*types.ValidatorSet + + errorToThrow error + errorHeight int64 + + err bool + }{ + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + 3: vals, + }, + err: false, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + }, + errorToThrow: provider.ErrBadLightBlock{Reason: errors.New("nil header or vals")}, + errorHeight: 3, + vals: valSet, + err: true, + }, + { + headers: map[int64]*types.SignedHeader{ + 1: h1, + 3: h3, + }, + vals: map[int64]*types.ValidatorSet{ + 1: vals, + 3: emptyValSet, + }, + err: true, + }, + } + + for i, tc := range testCases { + testCase := tc + t.Run(fmt.Sprintf("case: %d", i), func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mockBadNode := mockNodeFromHeadersAndVals(testCase.headers, testCase.vals) + if testCase.errorToThrow != nil { + mockBadNode.On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) + } + + c, err := light.NewClient( + ctx, + chainID, + trustOptions, + mockBadNode, + []provider.Provider{mockBadNode, mockBadNode}, + dbs.New(dbm.NewMemDB()), + ) + require.NoError(t, err) + + _, err = c.VerifyLightBlockAtHeight(ctx, 3, bTime.Add(2*time.Hour)) + if testCase.err { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + mockBadNode.AssertExpectations(t) + }) + } + }) } diff --git a/light/detector_test.go b/light/detector_test.go index f61d7f116..84b6f210c 100644 --- a/light/detector_test.go +++ b/light/detector_test.go @@ -35,7 +35,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, latestHeight, valSize, 2, bTime) forgedKeys := chainKeys[divergenceHeight-1].ChangeKeys(3) // we change 3 out of the 5 validators (still 2/5 remain) forgedVals := forgedKeys.ToValidators(2, 0) @@ -46,7 +46,7 @@ func TestLightClientAttackEvidence_Lunatic(t *testing.T) { primaryValidators[height] = witnessValidators[height] continue } - primaryHeaders[height] = forgedKeys.GenSignedHeader(chainID, height, bTime.Add(time.Duration(height)*time.Minute), + primaryHeaders[height] = forgedKeys.GenSignedHeader(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), nil, forgedVals, forgedVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, 
len(forgedKeys)) primaryValidators[height] = forgedVals } @@ -152,7 +152,7 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { // validators don't change in this network (however we still use a map just for convenience) primaryValidators = make(map[int64]*types.ValidatorSet, testCase.latestHeight) ) - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, testCase.latestHeight+1, valSize, 2, bTime) for height := int64(1); height <= testCase.latestHeight; height++ { if height < testCase.divergenceHeight { @@ -162,7 +162,7 @@ func TestLightClientAttackEvidence_Equivocation(t *testing.T) { } // we don't have a network partition so we will make 4/5 (greater than 2/3) malicious and vote again for // a different block (which we do by adding txs) - primaryHeaders[height] = chainKeys[height].GenSignedHeader(chainID, height, + primaryHeaders[height] = chainKeys[height].GenSignedHeader(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), []types.Tx{[]byte("abcd")}, witnessValidators[height], witnessValidators[height+1], hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(chainKeys[height])-1) @@ -246,7 +246,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { defer cancel() logger := log.NewTestingLogger(t) - witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(chainID, latestHeight, valSize, 2, bTime) + witnessHeaders, witnessValidators, chainKeys := genLightBlocksWithKeys(t, chainID, latestHeight, valSize, 2, bTime) for _, unusedHeader := range []int64{3, 5, 6, 8} { delete(witnessHeaders, unusedHeader) } @@ -262,7 +262,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { } forgedKeys := chainKeys[latestHeight].ChangeKeys(3) // we change 3 out of the 5 validators (still 2/5 remain) primaryValidators[forgedHeight] = forgedKeys.ToValidators(2, 0) - primaryHeaders[forgedHeight] = forgedKeys.GenSignedHeader( + primaryHeaders[forgedHeight] = forgedKeys.GenSignedHeader(t, chainID, forgedHeight, bTime.Add(time.Duration(latestHeight+1)*time.Minute), // 11 mins @@ -326,7 +326,7 @@ func TestLightClientAttackEvidence_ForwardLunatic(t *testing.T) { // to prove that there was an attack vals := chainKeys[latestHeight].ToValidators(2, 0) newLb := &types.LightBlock{ - SignedHeader: chainKeys[latestHeight].GenSignedHeader( + SignedHeader: chainKeys[latestHeight].GenSignedHeader(t, chainID, proofHeight, bTime.Add(time.Duration(proofHeight+1)*time.Minute), // 12 mins @@ -395,11 +395,11 @@ func TestClientDivergentTraces1(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - headers, vals, _ := genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 1, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(headers, vals) firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - headers, vals, _ = genLightBlocksWithKeys(chainID, 1, 5, 2, bTime) + headers, vals, _ = genLightBlocksWithKeys(t, chainID, 1, 5, 2, bTime) mockWitness := mockNodeFromHeadersAndVals(headers, vals) logger := log.NewTestingLogger(t) @@ -430,7 +430,7 @@ func TestClientDivergentTraces2(t *testing.T) { defer cancel() logger := log.NewTestingLogger(t) - headers, vals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + headers, vals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimaryNode := mockNodeFromHeadersAndVals(headers, 
vals) mockDeadNode := &provider_mocks.Provider{} mockDeadNode.On("LightBlock", mock.Anything, mock.Anything).Return(nil, provider.ErrNoResponse) @@ -465,7 +465,7 @@ func TestClientDivergentTraces3(t *testing.T) { logger := log.NewTestingLogger(t) // - primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) ctx, cancel := context.WithCancel(context.Background()) @@ -474,7 +474,7 @@ func TestClientDivergentTraces3(t *testing.T) { firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - mockHeaders, mockVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + mockHeaders, mockVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockHeaders[1] = primaryHeaders[1] mockVals[1] = primaryVals[1] mockWitness := mockNodeFromHeadersAndVals(mockHeaders, mockVals) @@ -508,7 +508,7 @@ func TestClientDivergentTraces4(t *testing.T) { logger := log.NewTestingLogger(t) // - primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + primaryHeaders, primaryVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) mockPrimary := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) ctx, cancel := context.WithCancel(context.Background()) @@ -517,7 +517,7 @@ func TestClientDivergentTraces4(t *testing.T) { firstBlock, err := mockPrimary.LightBlock(ctx, 1) require.NoError(t, err) - witnessHeaders, witnessVals, _ := genLightBlocksWithKeys(chainID, 2, 5, 2, bTime) + witnessHeaders, witnessVals, _ := genLightBlocksWithKeys(t, chainID, 2, 5, 2, bTime) primaryHeaders[2] = witnessHeaders[2] primaryVals[2] = witnessVals[2] mockWitness := mockNodeFromHeadersAndVals(primaryHeaders, primaryVals) diff --git a/light/helpers_test.go b/light/helpers_test.go index 1d25f9166..9f6147526 100644 --- a/light/helpers_test.go +++ b/light/helpers_test.go @@ -1,9 +1,11 @@ package light_test import ( + "testing" "time" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/tendermint/tendermint/crypto" "github.com/tendermint/tendermint/crypto/ed25519" "github.com/tendermint/tendermint/crypto/tmhash" @@ -74,7 +76,9 @@ func (pkz privKeys) ToValidators(init, inc int64) *types.ValidatorSet { } // signHeader properly signs the header with all keys from first to last exclusive. -func (pkz privKeys) signHeader(header *types.Header, valSet *types.ValidatorSet, first, last int) *types.Commit { +func (pkz privKeys) signHeader(t testing.TB, header *types.Header, valSet *types.ValidatorSet, first, last int) *types.Commit { + t.Helper() + commitSigs := make([]types.CommitSig, len(pkz)) for i := 0; i < len(pkz); i++ { commitSigs[i] = types.NewCommitSigAbsent() @@ -87,15 +91,15 @@ func (pkz privKeys) signHeader(header *types.Header, valSet *types.ValidatorSet, // Fill in the votes we want. 
for i := first; i < last && i < len(pkz); i++ { - vote := makeVote(header, valSet, pkz[i], blockID) + vote := makeVote(t, header, valSet, pkz[i], blockID) commitSigs[vote.ValidatorIndex] = vote.CommitSig() } return types.NewCommit(header.Height, 1, blockID, commitSigs) } -func makeVote(header *types.Header, valset *types.ValidatorSet, - key crypto.PrivKey, blockID types.BlockID) *types.Vote { +func makeVote(t testing.TB, header *types.Header, valset *types.ValidatorSet, key crypto.PrivKey, blockID types.BlockID) *types.Vote { + t.Helper() addr := key.PubKey().Address() idx, _ := valset.GetByAddress(addr) @@ -113,9 +117,7 @@ func makeVote(header *types.Header, valset *types.ValidatorSet, // Sign it signBytes := types.VoteSignBytes(header.ChainID, v) sig, err := key.Sign(signBytes) - if err != nil { - panic(err) - } + require.NoError(t, err) vote.Signature = sig @@ -143,26 +145,30 @@ func genHeader(chainID string, height int64, bTime time.Time, txs types.Txs, } // GenSignedHeader calls genHeader and signHeader and combines them into a SignedHeader. -func (pkz privKeys) GenSignedHeader(chainID string, height int64, bTime time.Time, txs types.Txs, +func (pkz privKeys) GenSignedHeader(t testing.TB, chainID string, height int64, bTime time.Time, txs types.Txs, valset, nextValset *types.ValidatorSet, appHash, consHash, resHash []byte, first, last int) *types.SignedHeader { + t.Helper() + header := genHeader(chainID, height, bTime, txs, valset, nextValset, appHash, consHash, resHash) return &types.SignedHeader{ Header: header, - Commit: pkz.signHeader(header, valset, first, last), + Commit: pkz.signHeader(t, header, valset, first, last), } } // GenSignedHeaderLastBlockID calls genHeader and signHeader and combines them into a SignedHeader. -func (pkz privKeys) GenSignedHeaderLastBlockID(chainID string, height int64, bTime time.Time, txs types.Txs, +func (pkz privKeys) GenSignedHeaderLastBlockID(t testing.TB, chainID string, height int64, bTime time.Time, txs types.Txs, valset, nextValset *types.ValidatorSet, appHash, consHash, resHash []byte, first, last int, lastBlockID types.BlockID) *types.SignedHeader { + t.Helper() + header := genHeader(chainID, height, bTime, txs, valset, nextValset, appHash, consHash, resHash) header.LastBlockID = lastBlockID return &types.SignedHeader{ Header: header, - Commit: pkz.signHeader(header, valset, first, last), + Commit: pkz.signHeader(t, header, valset, first, last), } } @@ -175,14 +181,14 @@ func (pkz privKeys) ChangeKeys(delta int) privKeys { // blocks to height. BlockIntervals are in per minute. // NOTE: Expected to have a large validator set size ~ 100 validators. 
func genLightBlocksWithKeys( + t testing.TB, chainID string, numBlocks int64, valSize int, valVariation float32, - bTime time.Time) ( - map[int64]*types.SignedHeader, - map[int64]*types.ValidatorSet, - map[int64]privKeys) { + bTime time.Time, +) (map[int64]*types.SignedHeader, map[int64]*types.ValidatorSet, map[int64]privKeys) { + t.Helper() var ( headers = make(map[int64]*types.SignedHeader, numBlocks) @@ -201,7 +207,7 @@ func genLightBlocksWithKeys( keymap[2] = newKeys // genesis header and vals - lastHeader := keys.GenSignedHeader(chainID, 1, bTime.Add(1*time.Minute), nil, + lastHeader := keys.GenSignedHeader(t, chainID, 1, bTime.Add(1*time.Minute), nil, keys.ToValidators(2, 0), newKeys.ToValidators(2, 0), hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) currentHeader := lastHeader @@ -214,7 +220,7 @@ func genLightBlocksWithKeys( valVariationInt = int(totalVariation) totalVariation = -float32(valVariationInt) newKeys = keys.ChangeKeys(valVariationInt) - currentHeader = keys.GenSignedHeaderLastBlockID(chainID, height, bTime.Add(time.Duration(height)*time.Minute), + currentHeader = keys.GenSignedHeaderLastBlockID(t, chainID, height, bTime.Add(time.Duration(height)*time.Minute), nil, keys.ToValidators(2, 0), newKeys.ToValidators(2, 0), hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys), types.BlockID{Hash: lastHeader.Hash()}) diff --git a/light/verifier_test.go b/light/verifier_test.go index 0432c130d..5a2019e21 100644 --- a/light/verifier_test.go +++ b/light/verifier_test.go @@ -28,7 +28,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) ) @@ -51,7 +51,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // different chainID -> error 1: { - keys.GenSignedHeader("different-chainID", nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, "different-chainID", nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -61,7 +61,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is before old header's time -> error 2: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(-1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(-1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 4 * time.Hour, @@ -71,7 +71,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is from the future -> error 3: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(3*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(3*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -81,7 +81,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // new header's time is from the future, but it's acceptable (< maxClockDrift) -> no error 4: { - keys.GenSignedHeader(chainID, nextHeight, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(2*time.Hour).Add(maxClockDrift).Add(-1*time.Millisecond), nil, vals, vals, hash("app_hash"), 
hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, @@ -92,7 +92,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 3/3 signed -> no error 5: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -102,7 +102,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 2/3 signed -> no error 6: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 1, len(keys)), vals, 3 * time.Hour, @@ -112,7 +112,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // 1/3 signed -> error 7: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), vals, 3 * time.Hour, @@ -122,7 +122,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // vals does not match with what we have -> error 8: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, keys.ToValidators(10, 1), vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, keys.ToValidators(10, 1), vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 3 * time.Hour, @@ -132,7 +132,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // vals are inconsistent with newHeader -> error 9: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 3 * time.Hour, @@ -142,7 +142,7 @@ func TestVerifyAdjacentHeaders(t *testing.T) { }, // old header has expired -> error 10: { - keys.GenSignedHeader(chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, nextHeight, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), keys.ToValidators(10, 1), 1 * time.Hour, @@ -180,7 +180,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! 
vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) // 30, 40, 50 @@ -206,7 +206,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }{ // 3/3 new vals signed, 3/3 old vals present -> no error 0: { - keys.GenSignedHeader(chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 3, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)), vals, 3 * time.Hour, @@ -216,7 +216,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 2/3 new vals signed, 3/3 old vals present -> no error 1: { - keys.GenSignedHeader(chainID, 4, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 4, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 1, len(keys)), vals, 3 * time.Hour, @@ -226,7 +226,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 1/3 new vals signed, 3/3 old vals present -> error 2: { - keys.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, vals, vals, + keys.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), len(keys)-1, len(keys)), vals, 3 * time.Hour, @@ -236,7 +236,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, 2/3 old vals present -> no error 3: { - twoThirds.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, twoThirdsVals, twoThirdsVals, + twoThirds.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, twoThirdsVals, twoThirdsVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(twoThirds)), twoThirdsVals, 3 * time.Hour, @@ -246,7 +246,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, 1/3 old vals present -> no error 4: { - oneThird.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, oneThirdVals, oneThirdVals, + oneThird.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, oneThirdVals, oneThirdVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(oneThird)), oneThirdVals, 3 * time.Hour, @@ -256,7 +256,7 @@ func TestVerifyNonAdjacentHeaders(t *testing.T) { }, // 3/3 new vals signed, less than 1/3 old vals present -> error 5: { - lessThanOneThird.GenSignedHeader(chainID, 5, bTime.Add(1*time.Hour), nil, lessThanOneThirdVals, lessThanOneThirdVals, + lessThanOneThird.GenSignedHeader(t, chainID, 5, bTime.Add(1*time.Hour), nil, lessThanOneThirdVals, lessThanOneThirdVals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(lessThanOneThird)), lessThanOneThirdVals, 3 * time.Hour, @@ -296,7 +296,7 @@ func TestVerifyReturnsErrorIfTrustLevelIsInvalid(t *testing.T) { // 20, 30, 40, 50 - the first 3 don't have 2/3, the last 3 do! vals = keys.ToValidators(20, 10) bTime, _ = time.Parse(time.RFC3339, "2006-01-02T15:04:05Z") - header = keys.GenSignedHeader(chainID, lastHeight, bTime, nil, vals, vals, + header = keys.GenSignedHeader(t, chainID, lastHeight, bTime, nil, vals, vals, hash("app_hash"), hash("cons_hash"), hash("results_hash"), 0, len(keys)) )
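
The recurring change across helpers_test.go (signHeader, makeVote, GenSignedHeader, GenSignedHeaderLastBlockID, genLightBlocksWithKeys) is that every helper now receives a testing.TB, marks itself with t.Helper() so failures are reported at the caller, and surfaces errors through require.NoError instead of panic(err). A minimal sketch of that convention, using a hypothetical signForTest helper (the ed25519 key and message are illustrative only, not part of the patch):

```go
package light_test

import (
	"testing"

	"github.com/stretchr/testify/require"
	"github.com/tendermint/tendermint/crypto/ed25519"
)

// signForTest is a hypothetical helper mirroring the convention adopted in
// this diff: take testing.TB, call t.Helper() so failure locations point at
// the caller, and fail via require.NoError rather than panicking.
func signForTest(t testing.TB, msg []byte) []byte {
	t.Helper()

	key := ed25519.GenPrivKey()
	sig, err := key.Sign(msg)
	require.NoError(t, err) // previously the helpers did: if err != nil { panic(err) }
	return sig
}

func TestSignForTest(t *testing.T) {
	require.NotEmpty(t, signForTest(t, []byte("vote sign bytes")))
}
```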
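Because the new parameter is testing.TB rather than *testing.T, the same helper can be driven by both tests and benchmarks (testing.TB is the interface common to *testing.T and *testing.B). A small sketch of that shape, with a hypothetical genFixture standing in for a generator such as genLightBlocksWithKeys:

```go
package light_test

import (
	"testing"

	"github.com/stretchr/testify/require"
)

// genFixture is a hypothetical stand-in for a fixture generator: accepting
// testing.TB lets *testing.T and *testing.B callers share it, and tb.Helper()
// keeps failure locations on the caller.
func genFixture(tb testing.TB, n int) []int {
	tb.Helper()
	require.Greater(tb, n, 0, "fixture size must be positive")

	out := make([]int, n)
	for i := range out {
		out[i] = i + 1
	}
	return out
}

func TestGenFixture(t *testing.T) {
	require.Len(t, genFixture(t, 3), 3)
}

func BenchmarkGenFixture(b *testing.B) {
	for i := 0; i < b.N; i++ {
		genFixture(b, 3)
	}
}
```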
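The EnsureValidHeadersAndValSets subtest programs a failure into the provider mock with On("LightBlock", mock.Anything, testCase.errorHeight).Return(nil, testCase.errorToThrow) and verifies the interaction with AssertExpectations. For readers less familiar with testify's mock package, here is a self-contained sketch of that flow; blockSource is a hypothetical stand-in for the generated provider mock and returns a plain string instead of a light block:

```go
package light_test

import (
	"context"
	"errors"
	"testing"

	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
)

// blockSource is a hypothetical, trimmed-down mock used only to show the
// On(...).Return(...) / AssertExpectations flow from the tests above.
type blockSource struct {
	mock.Mock
}

func (m *blockSource) LightBlock(ctx context.Context, height int64) (string, error) {
	args := m.Called(ctx, height)
	return args.String(0), args.Error(1)
}

func TestBlockSourceExpectations(t *testing.T) {
	src := new(blockSource)
	// Program the mock to fail at height 3, mirroring how errorToThrow and
	// errorHeight are wired into mockBadNode in the diff.
	src.On("LightBlock", mock.Anything, int64(3)).Return("", errors.New("nil header or vals"))

	_, err := src.LightBlock(context.Background(), 3)
	require.Error(t, err)
	src.AssertExpectations(t)
}
```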