From 1905ef7ca8ab4a8a8bfc1bba9b020f6be6061d2f Mon Sep 17 00:00:00 2001 From: Anton Kaliaev Date: Wed, 22 Jan 2020 20:26:47 +0400 Subject: [PATCH] lite2: improve auto update (#4334) * lite2: advance to latest header without any exponential steps rename autoUpdate to autoUpdateRoutine * lite2: wait in Cleanup until goroutines finished running --- lite2/client.go | 47 ++++++++++++++++++++++--------------- lite2/client_test.go | 16 +++++-------- lite2/example_test.go | 12 +++++++--- lite2/provider/mock/mock.go | 6 +++++ 4 files changed, 49 insertions(+), 32 deletions(-) diff --git a/lite2/client.go b/lite2/client.go index 9d95817f0..ad7f8415f 100644 --- a/lite2/client.go +++ b/lite2/client.go @@ -3,6 +3,7 @@ package lite import ( "bytes" "fmt" + "sync" "time" "github.com/pkg/errors" @@ -154,6 +155,7 @@ type Client struct { updatePeriod time.Duration removeNoLongerTrustedHeadersPeriod time.Duration + routinesWaitGroup sync.WaitGroup confirmationFn func(action string) bool @@ -208,11 +210,13 @@ func NewClient( } if c.removeNoLongerTrustedHeadersPeriod > 0 { + c.routinesWaitGroup.Add(1) go c.removeNoLongerTrustedHeadersRoutine() } if c.updatePeriod > 0 { - go c.autoUpdate() + c.routinesWaitGroup.Add(1) + go c.autoUpdateRoutine() } return c, nil @@ -474,7 +478,7 @@ func (c *Client) VerifyHeaderAtHeight(height int64, now time.Time) (*types.Signe // provider, provider.ErrSignedHeaderNotFound / // provider.ErrValidatorSetNotFound error is returned. func (c *Client) VerifyHeader(newHeader *types.SignedHeader, newVals *types.ValidatorSet, now time.Time) error { - c.logger.Info("VerifyHeader", "height", newHeader.Hash(), "newVals", newVals.Hash()) + c.logger.Info("VerifyHeader", "height", newHeader.Hash(), "newVals", fmt.Sprintf("%X", newVals.Hash())) if c.trustedHeader.Height >= newHeader.Height { return errors.Errorf("header at more recent height #%d exists", c.trustedHeader.Height) @@ -507,8 +511,11 @@ func (c *Client) VerifyHeader(newHeader *types.SignedHeader, newVals *types.Vali return c.updateTrustedHeaderAndVals(newHeader, nextVals) } -// Cleanup removes all the data (headers and validator sets) stored. +// Cleanup removes all the data (headers and validator sets) stored. It blocks +// until internal routines are finished. Note: the client must be stopped at +// this point. func (c *Client) Cleanup() error { + c.routinesWaitGroup.Wait() c.logger.Info("Cleanup everything") return c.cleanup(0) } @@ -707,6 +714,8 @@ func (c *Client) compareNewHeaderWithRandomAlternative(h *types.SignedHeader) er } func (c *Client) removeNoLongerTrustedHeadersRoutine() { + defer c.routinesWaitGroup.Done() + ticker := time.NewTicker(c.removeNoLongerTrustedHeadersPeriod) defer ticker.Stop() @@ -760,14 +769,16 @@ func (c *Client) RemoveNoLongerTrustedHeaders(now time.Time) { } } -func (c *Client) autoUpdate() { +func (c *Client) autoUpdateRoutine() { + defer c.routinesWaitGroup.Done() + ticker := time.NewTicker(c.updatePeriod) defer ticker.Stop() for { select { case <-ticker.C: - err := c.AutoUpdate(time.Now()) + err := c.Update(time.Now()) if err != nil { c.logger.Error("Error during auto update", "err", err) } @@ -777,12 +788,12 @@ func (c *Client) autoUpdate() { } } -// AutoUpdate attempts to advance the state making exponential steps (note: +// Update attempts to advance the state making exponential steps (note: // when SequentialVerification is being used, the client will still be // downloading all intermediate headers). // // Exposed for testing. -func (c *Client) AutoUpdate(now time.Time) error { +func (c *Client) Update(now time.Time) error { lastTrustedHeight, err := c.LastTrustedHeight() if err != nil { return errors.Wrap(err, "can't get last trusted height") @@ -793,20 +804,18 @@ func (c *Client) AutoUpdate(now time.Time) error { return nil } - var i int64 - for err == nil { - // exponential increment: 1, 2, 4, 8, 16, ... - height := lastTrustedHeight + int64(1< lastTrustedHeight { + err = c.VerifyHeader(latestHeader, latestVals, now) if err != nil { - if errors.Is(err, provider.ErrSignedHeaderNotFound) { - c.logger.Debug("No header yet", "at", height) - return nil - } - return errors.Wrapf(err, "failed to verify the header #%d", height) + return err } - c.logger.Info("Advanced to new state", "height", h.Height, "hash", h.Hash()) - i++ + + c.logger.Info("Advanced to new state", "height", latestHeader.Height, "hash", latestHeader.Hash()) } return nil diff --git a/lite2/client_test.go b/lite2/client_test.go index eea533208..3065b76c0 100644 --- a/lite2/client_test.go +++ b/lite2/client_test.go @@ -350,6 +350,7 @@ func TestClient_Cleanup(t *testing.T) { ) require.NoError(t, err) + c.Stop() c.Cleanup() // Check no headers exist after Cleanup. @@ -661,9 +662,9 @@ func TestClientRestoreTrustedHeaderAfterStartup3(t *testing.T) { } } -func TestClient_AutoUpdate(t *testing.T) { +func TestClient_Update(t *testing.T) { const ( - chainID = "TestClient_AutoUpdate" + chainID = "TestClient_Update" ) var ( @@ -707,16 +708,11 @@ func TestClient_AutoUpdate(t *testing.T) { require.NoError(t, err) defer c.Stop() - // should result in downloading & verifying headers #2 and #3 - err = c.AutoUpdate(bTime.Add(2 * time.Hour)) + // should result in downloading & verifying header #3 + err = c.Update(bTime.Add(2 * time.Hour)) require.NoError(t, err) - h, err := c.TrustedHeader(2, bTime.Add(2*time.Hour)) - assert.NoError(t, err) - require.NotNil(t, h) - assert.EqualValues(t, 2, h.Height) - - h, err = c.TrustedHeader(3, bTime.Add(2*time.Hour)) + h, err := c.TrustedHeader(3, bTime.Add(2*time.Hour)) assert.NoError(t, err) require.NotNil(t, h) assert.EqualValues(t, 3, h.Height) diff --git a/lite2/example_test.go b/lite2/example_test.go index 772012a96..48804c77d 100644 --- a/lite2/example_test.go +++ b/lite2/example_test.go @@ -63,11 +63,14 @@ func TestExample_Client_AutoUpdate(t *testing.T) { if err != nil { stdlog.Fatal(err) } - defer c.Stop() + defer func() { + c.Stop() + c.Cleanup() + }() time.Sleep(2 * time.Second) - h, err := c.TrustedHeader(3, time.Now()) + h, err := c.TrustedHeader(0, time.Now()) if err != nil { stdlog.Fatal(err) } @@ -122,7 +125,10 @@ func TestExample_Client_ManualUpdate(t *testing.T) { if err != nil { stdlog.Fatal(err) } - defer c.Stop() + defer func() { + c.Stop() + c.Cleanup() + }() _, err = c.VerifyHeaderAtHeight(3, time.Now()) if err != nil { diff --git a/lite2/provider/mock/mock.go b/lite2/provider/mock/mock.go index 3d8cd6876..f895420e5 100644 --- a/lite2/provider/mock/mock.go +++ b/lite2/provider/mock/mock.go @@ -27,6 +27,9 @@ func (p *mock) ChainID() string { } func (p *mock) SignedHeader(height int64) (*types.SignedHeader, error) { + if height == 0 && len(p.headers) > 0 { + return p.headers[int64(len(p.headers))], nil + } if _, ok := p.headers[height]; ok { return p.headers[height], nil } @@ -34,6 +37,9 @@ func (p *mock) SignedHeader(height int64) (*types.SignedHeader, error) { } func (p *mock) ValidatorSet(height int64) (*types.ValidatorSet, error) { + if height == 0 && len(p.vals) > 0 { + return p.vals[int64(len(p.vals))], nil + } if _, ok := p.vals[height]; ok { return p.vals[height], nil }