From f6569b5dcd1a3ad7fc9dfd48bb6ac3e2045a0d5a Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Thu, 17 Feb 2022 13:56:11 -0500 Subject: [PATCH] abci/client: remove waitgroup for requests (#7842) * abci/client: remove awkward waitgroup * elide done --- abci/client/client.go | 36 ++++++++++++++---------------------- abci/client/grpc_client.go | 1 - abci/client/socket_client.go | 15 ++++----------- 3 files changed, 18 insertions(+), 34 deletions(-) diff --git a/abci/client/client.go b/abci/client/client.go index e535aa028..5dbaeaf1f 100644 --- a/abci/client/client.go +++ b/abci/client/client.go @@ -74,22 +74,19 @@ type Callback func(*types.Request, *types.Response) type ReqRes struct { *types.Request - *sync.WaitGroup *types.Response // Not set atomically, so be sure to use WaitGroup. - mtx sync.Mutex - done bool // Gets set to true once *after* WaitGroup.Done(). - cb func(*types.Response) // A single callback that may be set. + mtx sync.Mutex + signal chan struct{} + cb func(*types.Response) // A single callback that may be set. } func NewReqRes(req *types.Request) *ReqRes { return &ReqRes{ - Request: req, - WaitGroup: waitGroup1(), - Response: nil, - - done: false, - cb: nil, + Request: req, + Response: nil, + signal: make(chan struct{}), + cb: nil, } } @@ -99,14 +96,14 @@ func NewReqRes(req *types.Request) *ReqRes { func (r *ReqRes) SetCallback(cb func(res *types.Response)) { r.mtx.Lock() - if r.done { + select { + case <-r.signal: r.mtx.Unlock() cb(r.Response) - return + default: + r.cb = cb + r.mtx.Unlock() } - - r.cb = cb - r.mtx.Unlock() } // InvokeCallback invokes a thread-safe execution of the configured callback @@ -135,12 +132,7 @@ func (r *ReqRes) GetCallback() func(*types.Response) { // SetDone marks the ReqRes object as done. func (r *ReqRes) SetDone() { r.mtx.Lock() - r.done = true - r.mtx.Unlock() -} + defer r.mtx.Unlock() -func waitGroup1() (wg *sync.WaitGroup) { - wg = &sync.WaitGroup{} - wg.Add(1) - return + close(r.signal) } diff --git a/abci/client/grpc_client.go b/abci/client/grpc_client.go index 1cdae1abb..936bb0f73 100644 --- a/abci/client/grpc_client.go +++ b/abci/client/grpc_client.go @@ -77,7 +77,6 @@ func (cli *grpcClient) OnStart(ctx context.Context) error { defer cli.mtx.Unlock() reqres.SetDone() - reqres.Done() // Notify client listener if set if cli.resCb != nil { diff --git a/abci/client/socket_client.go b/abci/client/socket_client.go index 074dd1d00..674bcdf7a 100644 --- a/abci/client/socket_client.go +++ b/abci/client/socket_client.go @@ -197,7 +197,7 @@ func (cli *socketClient) didRecvResponse(res *types.Response) error { } reqres.Response = res - reqres.Done() // release waiters + reqres.SetDone() // release waiters cli.reqSent.Remove(next) // pop first item from linked list // Notify client listener if set (global callback). @@ -236,15 +236,8 @@ func (cli *socketClient) Flush(ctx context.Context) error { return err } - gotResp := make(chan struct{}) - go func() { - // NOTE: if we don't flush the queue, its possible to get stuck here - reqRes.Wait() - close(gotResp) - }() - select { - case <-gotResp: + case <-reqRes.signal: return cli.Error() case <-ctx.Done(): return ctx.Err() @@ -487,7 +480,7 @@ func (cli *socketClient) drainQueue(ctx context.Context) { // mark all in-flight messages as resolved (they will get cli.Error()) for req := cli.reqSent.Front(); req != nil; req = req.Next() { reqres := req.Value.(*ReqRes) - reqres.Done() + reqres.SetDone() } // Mark all queued messages as resolved. @@ -500,7 +493,7 @@ func (cli *socketClient) drainQueue(ctx context.Context) { case <-ctx.Done(): return case reqres := <-cli.reqQueue: - reqres.Done() + reqres.SetDone() default: return }