
libs/service: regularize Stop semantics and concurrency primitives (#7809)

Branch: pull/7820/head
Author: Sam Kleinman (committed by GitHub), 2 years ago
Commit: 824960c565
9 changed files with 224 additions and 114 deletions
  1. cmd/tendermint/commands/rollback_test.go   +9   -5
  2. cmd/tendermint/commands/run_node.go        +1   -1
  3. internal/blocksync/pool_test.go            +6   -8
  4. internal/p2p/transport_mconn.go            +24  -8
  5. internal/p2p/transport_mconn_test.go       +0   -3
  6. libs/service/service.go                    +69  -64
  7. libs/service/service_test.go               +110 -20
  8. node/node.go                               +0   -1
  9. node/node_test.go                          +5   -4

cmd/tendermint/commands/rollback_test.go (+9 -5)

@@ -22,7 +22,9 @@ func TestRollbackIntegration(t *testing.T) {
 	cfg, err := rpctest.CreateConfig(t.Name())
 	require.NoError(t, err)
 	cfg.BaseConfig.DBBackend = "goleveldb"
 	app, err := e2e.NewApplication(e2e.DefaultConfig(dir))
+	require.NoError(t, err)
+
 	t.Run("First run", func(t *testing.T) {
 		ctx, cancel := context.WithCancel(ctx)
@@ -30,27 +32,29 @@ func TestRollbackIntegration(t *testing.T) {
 		require.NoError(t, err)
 		node, _, err := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout)
 		require.NoError(t, err)
+		require.True(t, node.IsRunning())
 
 		time.Sleep(3 * time.Second)
 		cancel()
 		node.Wait()
 
 		require.False(t, node.IsRunning())
 	})
 	t.Run("Rollback", func(t *testing.T) {
+		time.Sleep(time.Second)
 		require.NoError(t, app.Rollback())
 		height, _, err = commands.RollbackState(cfg)
-		require.NoError(t, err)
+		require.NoError(t, err, "%d", height)
 	})
 	t.Run("Restart", func(t *testing.T) {
+		require.True(t, height > 0, "%d", height)
 		ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
 		defer cancel()
 		node2, _, err2 := rpctest.StartTendermint(ctx, cfg, app, rpctest.SuppressStdout)
 		require.NoError(t, err2)
 
-		logger := log.NewTestingLogger(t)
+		logger := log.NewNopLogger()
 
 		client, err := local.New(logger, node2.(local.NodeService))
 		require.NoError(t, err)


cmd/tendermint/commands/run_node.go (+1 -1)

@@ -117,7 +117,7 @@ func NewRunNodeCmd(nodeProvider cfg.ServiceProvider, conf *cfg.Config, logger lo
 				return fmt.Errorf("failed to start node: %w", err)
 			}
 
-			logger.Info("started node", "node", n.String())
+			logger.Info("started node", "chain", conf.ChainID())
 
 			<-ctx.Done()
 			return nil


internal/blocksync/pool_test.go (+6 -8)

@@ -125,7 +125,6 @@ func TestBlockPoolBasic(t *testing.T) {
 		case err := <-errorsCh:
 			t.Error(err)
 		case request := <-requestsCh:
-			t.Logf("Pulled new BlockRequest %v", request)
 			if request.Height == 300 {
 				return // Done!
 			}
@@ -139,21 +138,19 @@ func TestBlockPoolTimeout(t *testing.T) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
+	logger := log.TestingLogger()
+
 	start := int64(42)
 	peers := makePeers(10, start+1, 1000)
 	errorsCh := make(chan peerError, 1000)
 	requestsCh := make(chan BlockRequest, 1000)
-	pool := NewBlockPool(log.TestingLogger(), start, requestsCh, errorsCh)
+	pool := NewBlockPool(logger, start, requestsCh, errorsCh)
 	err := pool.Start(ctx)
 	if err != nil {
 		t.Error(err)
 	}
 
 	t.Cleanup(func() { cancel(); pool.Wait() })
-	for _, peer := range peers {
-		t.Logf("Peer %v", peer.id)
-	}
 	// Introduce each peer.
 	go func() {
 		for _, peer := range peers {
@@ -182,7 +179,6 @@ func TestBlockPoolTimeout(t *testing.T) {
 	for {
 		select {
 		case err := <-errorsCh:
-			t.Log(err)
 			// consider error to be always timeout here
 			if _, ok := timedOut[err.peerID]; !ok {
 				counter++
@@ -191,7 +187,9 @@ func TestBlockPoolTimeout(t *testing.T) {
 			}
 		case request := <-requestsCh:
-			t.Logf("Pulled new BlockRequest %+v", request)
+			logger.Debug("received request",
+				"counter", counter,
+				"request", request)
 		}
 	}
 }


internal/p2p/transport_mconn.go (+24 -8)

@@ -138,19 +138,35 @@ func (m *MConnTransport) Accept(ctx context.Context) (Connection, error) {
 		return nil, errors.New("transport is not listening")
 	}
 
-	tcpConn, err := m.listener.Accept()
-	if err != nil {
+	conCh := make(chan net.Conn)
+	errCh := make(chan error)
+	go func() {
+		tcpConn, err := m.listener.Accept()
+		if err != nil {
+			select {
+			case errCh <- err:
+			case <-ctx.Done():
+			}
+		}
+
 		select {
+		case conCh <- tcpConn:
 		case <-ctx.Done():
-			return nil, io.EOF
 		case <-m.doneCh:
-			return nil, io.EOF
-		default:
-			return nil, err
 		}
+	}()
+
+	select {
+	case <-ctx.Done():
+		m.listener.Close()
+		return nil, io.EOF
+	case <-m.doneCh:
+		m.listener.Close()
+		return nil, io.EOF
+	case err := <-errCh:
+		return nil, err
+	case tcpConn := <-conCh:
+		return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil
 	}
-
-	return newMConnConnection(m.logger, tcpConn, m.mConnConfig, m.channelDescs), nil
 }
 
 // Dial implements Transport.
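
Note on the Accept change: it is an instance of the general Go pattern for making a blocking call cancellable when the callee offers no deadline hook — run the blocking call in a goroutine that reports over channels, then select on those channels alongside the cancellation signals. A stripped-down, self-contained sketch of the same pattern follows; the acceptCtx helper is hypothetical and not part of this package:

package example

import (
	"context"
	"net"
)

// acceptCtx waits for one connection from ln, but gives up as soon as ctx
// is canceled. The buffered channels let the goroutine deliver its result
// and exit even if the caller has already stopped listening.
func acceptCtx(ctx context.Context, ln net.Listener) (net.Conn, error) {
	connCh := make(chan net.Conn, 1)
	errCh := make(chan error, 1)

	go func() {
		conn, err := ln.Accept()
		if err != nil {
			errCh <- err
			return
		}
		select {
		case connCh <- conn:
		case <-ctx.Done():
			conn.Close() // caller is gone; close so the conn does not leak
		}
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case err := <-errCh:
		return nil, err
	case conn := <-connCh:
		return conn, nil
	}
}

Unlike this sketch, the diff above also closes the listener on cancellation, which forces the pending Accept inside the goroutine to return with an error rather than blocking indefinitely.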


internal/p2p/transport_mconn_test.go (+0 -3)

@@ -154,9 +154,6 @@ func TestMConnTransport_Listen(t *testing.T) {
 		t.Run(tc.endpoint.String(), func(t *testing.T) {
 			t.Cleanup(leaktest.Check(t))
 
-			ctx, cancel = context.WithCancel(ctx)
-			defer cancel()
-
 			transport := p2p.NewMConnTransport(
 				log.TestingLogger(),
 				conn.DefaultMConnConfig(),


libs/service/service.go (+69 -64)

@@ -3,7 +3,7 @@ package service
 import (
 	"context"
 	"errors"
-	"sync/atomic"
+	"sync"
 
 	"github.com/tendermint/tendermint/libs/log"
 )
@@ -30,9 +30,6 @@ type Service interface {
 	// Return true if the service is running
 	IsRunning() bool
 
-	// String representation of the service
-	String() string
-
 	// Wait blocks until the service is stopped.
 	Wait()
 }
@@ -40,8 +37,6 @@ type Service interface {
 // Implementation describes the implementation that the
 // BaseService implementation wraps.
 type Implementation interface {
-	Service
-
 	// Called by the Services Start Method
 	OnStart(context.Context) error
@@ -57,12 +52,7 @@ Users can override the OnStart/OnStop methods. In the absence of errors, these
 methods are guaranteed to be called at most once. If OnStart returns an error,
 service won't be marked as started, so the user can call Start again.
 
-A call to Reset will panic, unless OnReset is overwritten, allowing
-OnStart/OnStop to be called again.
-
-The caller must ensure that Start and Stop are not called concurrently.
-
-It is ok to call Stop without calling Start first.
+It is safe, but an error, to call Stop without calling Start first.
 
 Typical usage:
@@ -80,23 +70,21 @@ Typical usage:
 	}
 
 	func (fs *FooService) OnStart(ctx context.Context) error {
-		fs.BaseService.OnStart() // Always call the overridden method.
 		// initialize private fields
 		// start subroutines, etc.
 	}
 
 	func (fs *FooService) OnStop() error {
-		fs.BaseService.OnStop() // Always call the overridden method.
 		// close/destroy private fields
 		// stop subroutines, etc.
 	}
 */
 type BaseService struct {
-	logger  log.Logger
-	name    string
-	started uint32 // atomic
-	stopped uint32 // atomic
-	quit    chan struct{}
+	logger log.Logger
+	name   string
+	mtx    sync.Mutex
+	quit   <-chan (struct{})
+	cancel context.CancelFunc
 
 	// The "subclass" of BaseService
 	impl Implementation
@@ -107,7 +95,6 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe
 	return &BaseService{
 		logger: logger,
 		name:   name,
-		quit:   make(chan struct{}),
 		impl:   impl,
 	}
 }
@@ -116,83 +103,101 @@ func NewBaseService(logger log.Logger, name string, impl Implementation) *BaseSe
 // returned if the service is already running or stopped. To restart a
 // stopped service, call Reset.
 func (bs *BaseService) Start(ctx context.Context) error {
-	if atomic.CompareAndSwapUint32(&bs.started, 0, 1) {
-		if atomic.LoadUint32(&bs.stopped) == 1 {
-			bs.logger.Error("not starting service; already stopped", "service", bs.name, "impl", bs.impl.String())
-			atomic.StoreUint32(&bs.started, 0)
-			return ErrAlreadyStopped
-		}
+	bs.mtx.Lock()
+	defer bs.mtx.Unlock()
 
-		bs.logger.Info("starting service", "service", bs.name, "impl", bs.impl.String())
+	if bs.quit != nil {
+		return ErrAlreadyStarted
+	}
 
+	select {
+	case <-bs.quit:
+		return ErrAlreadyStopped
+	default:
+		bs.logger.Info("starting service", "service", bs.name, "impl", bs.name)
 		if err := bs.impl.OnStart(ctx); err != nil {
-			// revert flag
-			atomic.StoreUint32(&bs.started, 0)
 			return err
 		}
 
+		// we need a separate context to ensure that we start
+		// a thread that will get cleaned up and that the
+		// Stop/Wait functions work as expected.
+		srvCtx, cancel := context.WithCancel(context.Background())
+		bs.cancel = cancel
+		bs.quit = srvCtx.Done()
+
 		go func(ctx context.Context) {
 			select {
-			case <-bs.quit:
-				// someone else explicitly called stop
-				// and then we shouldn't.
+			case <-srvCtx.Done():
+				// this means stop was called manually
 				return
 			case <-ctx.Done():
-				// if nothing is running, no need to
-				// shut down again.
-				if !bs.impl.IsRunning() {
-					return
-				}
-
-				// the context was cancel and we
-				// should stop.
-				if err := bs.Stop(); err != nil {
-					bs.logger.Error("stopped service",
-						"err", err.Error(),
-						"service", bs.name,
-						"impl", bs.impl.String())
-				}
-
-				bs.logger.Info("stopped service",
-					"service", bs.name,
-					"impl", bs.impl.String())
+				_ = bs.Stop()
 			}
+
+			bs.logger.Info("stopped service",
+				"service", bs.name)
 		}(ctx)
 
 		return nil
 	}
-
-	return ErrAlreadyStarted
 }
 
 // Stop implements Service by calling OnStop (if defined) and closing quit
 // channel. An error will be returned if the service is already stopped.
 func (bs *BaseService) Stop() error {
-	if atomic.CompareAndSwapUint32(&bs.stopped, 0, 1) {
-		if atomic.LoadUint32(&bs.started) == 0 {
-			bs.logger.Error("not stopping service; not started yet", "service", bs.name, "impl", bs.impl.String())
-			atomic.StoreUint32(&bs.stopped, 0)
-			return ErrNotStarted
-		}
+	bs.mtx.Lock()
+	defer bs.mtx.Unlock()
 
-		bs.logger.Info("stopping service", "service", bs.name, "impl", bs.impl.String())
+	if bs.quit == nil {
+		return ErrNotStarted
+	}
 
+	select {
+	case <-bs.quit:
+		return ErrAlreadyStopped
+	default:
+		bs.logger.Info("stopping service", "service", bs.name)
 		bs.impl.OnStop()
-		close(bs.quit)
+		bs.cancel()
 
 		return nil
 	}
-
-	return ErrAlreadyStopped
 }
 
 // IsRunning implements Service by returning true or false depending on the
 // service's state.
 func (bs *BaseService) IsRunning() bool {
-	return atomic.LoadUint32(&bs.started) == 1 && atomic.LoadUint32(&bs.stopped) == 0
+	bs.mtx.Lock()
+	defer bs.mtx.Unlock()
+
+	if bs.quit == nil {
+		return false
+	}
+
+	select {
+	case <-bs.quit:
+		return false
+	default:
+		return true
+	}
+}
+
+func (bs *BaseService) getWait() <-chan struct{} {
+	bs.mtx.Lock()
+	defer bs.mtx.Unlock()
+
+	if bs.quit == nil {
+		out := make(chan struct{})
+		close(out)
+		return out
+	}
+
+	return bs.quit
 }
 
 // Wait blocks until the service is stopped.
-func (bs *BaseService) Wait() { <-bs.quit }
+func (bs *BaseService) Wait() { <-bs.getWait() }
 
 // String implements Service by returning a string representation of the service.
 func (bs *BaseService) String() string { return bs.name }
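
The net effect of the service.go rewrite: a single mutex-guarded quit channel, derived from an internal context, replaces the two atomic started/stopped flags, so Start, Stop, IsRunning, and Wait can no longer observe half-updated state, and a service stops either when its Start context is canceled or when Stop is called explicitly, whichever comes first. A minimal usage sketch under the new semantics, reusing the FooService name from the package doc comment (the embedding boilerplate is illustrative, not copied from the repository):

package main

import (
	"context"

	"github.com/tendermint/tendermint/libs/log"
	"github.com/tendermint/tendermint/libs/service"
)

type FooService struct {
	service.BaseService
}

func NewFooService(logger log.Logger) *FooService {
	fs := &FooService{}
	fs.BaseService = *service.NewBaseService(logger, "FooService", fs)
	return fs
}

func (fs *FooService) OnStart(ctx context.Context) error { return nil } // start subroutines, etc.
func (fs *FooService) OnStop()                           {}             // stop subroutines, etc.

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	fs := NewFooService(log.NewNopLogger())
	if err := fs.Start(ctx); err != nil {
		panic(err)
	}

	cancel()  // canceling the Start context stops the service in the background...
	fs.Wait() // ...and Wait unblocks once OnStop has run.

	// A later explicit Stop is an error (ErrAlreadyStopped),
	// not a second call to OnStop.
	_ = fs.Stop()
}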

libs/service/service_test.go (+110 -20)

@@ -2,45 +2,135 @@ package service
 
 import (
 	"context"
+	"sync"
 	"testing"
 	"time"
 
+	"github.com/fortytw2/leaktest"
+	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
 
 	"github.com/tendermint/tendermint/libs/log"
 )
 
 type testService struct {
+	started      bool
+	stopped      bool
+	multiStopped bool
+	mu           sync.Mutex
 	BaseService
 }
 
-func (testService) OnStop() {}
-func (testService) OnStart(context.Context) error {
+func (t *testService) OnStop() {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.stopped == true {
+		t.multiStopped = true
+	}
+	t.stopped = true
+}
+
+func (t *testService) OnStart(context.Context) error {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.started = true
 	return nil
 }
 
-func TestBaseServiceWait(t *testing.T) {
+func (t *testService) isStarted() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.started
+}
+
+func (t *testService) isStopped() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.stopped
+}
+
+func (t *testService) isMultiStopped() bool {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	return t.multiStopped
+}
+
+func TestBaseService(t *testing.T) {
+	t.Cleanup(leaktest.Check(t))
+
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 
-	logger := log.NewTestingLogger(t)
+	logger := log.NewNopLogger()
 
-	ts := &testService{}
-	ts.BaseService = *NewBaseService(logger, "TestService", ts)
-	err := ts.Start(ctx)
-	require.NoError(t, err)
+	t.Run("Wait", func(t *testing.T) {
+		wctx, wcancel := context.WithCancel(ctx)
+		defer wcancel()
+
+		ts := &testService{}
+		ts.BaseService = *NewBaseService(logger, t.Name(), ts)
+		err := ts.Start(wctx)
+		require.NoError(t, err)
+		require.True(t, ts.isStarted())
 
-	waitFinished := make(chan struct{})
-	go func() {
-		ts.Wait()
-		waitFinished <- struct{}{}
-	}()
+		waitFinished := make(chan struct{})
+		wcancel()
 
-	go cancel()
+		go func() {
+			ts.Wait()
+			close(waitFinished)
+		}()
 
+		select {
+		case <-waitFinished:
+			assert.True(t, ts.isStopped(), "failed to stop")
+			assert.False(t, ts.IsRunning(), "is not running")
+		case <-time.After(100 * time.Millisecond):
+			t.Fatal("expected Wait() to finish within 100 ms.")
+		}
+	})
+
+	t.Run("ManualStop", func(t *testing.T) {
+		ts := &testService{}
+		ts.BaseService = *NewBaseService(logger, t.Name(), ts)
+		require.False(t, ts.IsRunning())
+		require.False(t, ts.isStarted())
+		require.NoError(t, ts.Start(ctx))
+		require.True(t, ts.isStarted())
+		require.NoError(t, ts.Stop())
+		require.True(t, ts.isStopped())
+		require.False(t, ts.IsRunning())
+	})
+
+	t.Run("MultiStop", func(t *testing.T) {
+		t.Run("SingleThreaded", func(t *testing.T) {
+			ts := &testService{}
+			ts.BaseService = *NewBaseService(logger, t.Name(), ts)
+			require.NoError(t, ts.Start(ctx))
+			require.True(t, ts.isStarted())
+			require.NoError(t, ts.Stop())
+			require.True(t, ts.isStopped())
+			require.False(t, ts.isMultiStopped())
+			require.Error(t, ts.Stop())
+			require.False(t, ts.isMultiStopped())
+		})
+		t.Run("MultiThreaded", func(t *testing.T) {
+			ctx, cancel := context.WithCancel(context.Background())
+			defer cancel()
+
+			ts := &testService{}
+			ts.BaseService = *NewBaseService(logger, t.Name(), ts)
+			require.NoError(t, ts.Start(ctx))
+			require.True(t, ts.isStarted())
+
+			go func() { _ = ts.Stop() }()
+			go cancel()
+
+			ts.Wait()
+			require.True(t, ts.isStopped())
+			require.False(t, ts.isMultiStopped())
+		})
+	})
-
-	select {
-	case <-waitFinished:
-		// all good
-	case <-time.After(100 * time.Millisecond):
-		t.Fatal("expected Wait() to finish within 100 ms.")
-	}
 }
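
service_test.go now depends on github.com/fortytw2/leaktest to assert that no goroutines outlive the test. For reference, a self-contained sketch of how leaktest.Check behaves, independent of this repository:

package example

import (
	"testing"
	"time"

	"github.com/fortytw2/leaktest"
)

func TestNoGoroutineLeak(t *testing.T) {
	// leaktest.Check snapshots the live goroutines now and returns a
	// function that fails the test if extra goroutines remain when it runs.
	t.Cleanup(leaktest.Check(t))

	done := make(chan struct{})
	go func() {
		time.Sleep(10 * time.Millisecond)
		close(done)
	}()
	<-done // the goroutine exits before the test returns, so the check passes
}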

node/node.go (+0 -1)

@@ -550,7 +550,6 @@ func (n *nodeImpl) OnStart(ctx context.Context) error {
 // OnStop stops the Node. It implements service.Service.
 func (n *nodeImpl) OnStop() {
 	n.logger.Info("Stopping Node")
-
 	for _, es := range n.eventSinks {
 		if err := es.Stop(); err != nil {
 			n.logger.Error("failed to stop event sink", "err", err)

node/node_test.go (+5 -4)

@@ -55,11 +55,10 @@ func TestNodeStartStop(t *testing.T) {
 	n, ok := ns.(*nodeImpl)
 	require.True(t, ok)
 	t.Cleanup(func() {
-		if n.IsRunning() {
-			bcancel()
-			n.Wait()
-		}
+		bcancel()
+		n.Wait()
 	})
+	t.Cleanup(leaktest.CheckTimeout(t, time.Second))
 
 	require.NoError(t, n.Start(ctx))
 	// wait for the node to produce a block
@@ -98,6 +97,7 @@ func getTestNode(ctx context.Context, t *testing.T, conf *config.Config, logger
 			ns.Wait()
 		}
 	})
+
 	t.Cleanup(leaktest.CheckTimeout(t, time.Second))
 
 	return n
@@ -568,6 +568,7 @@ func TestNodeNewSeedNode(t *testing.T) {
 		logger,
 	)
 	t.Cleanup(ns.Wait)
+	t.Cleanup(leaktest.CheckTimeout(t, time.Second))
 	require.NoError(t, err)
 
 	n, ok := ns.(*seedNodeImpl)
