From a6d20a6660b8c17bf66822c05bb7bc66490d03fe Mon Sep 17 00:00:00 2001 From: Sam Kleinman Date: Wed, 11 Aug 2021 11:37:05 -0400 Subject: [PATCH] pubsub: unsubscribe locking handling (#6816) --- libs/pubsub/pubsub.go | 48 +++++++++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 18 deletions(-) diff --git a/libs/pubsub/pubsub.go b/libs/pubsub/pubsub.go index 54a030fe8..7548470b5 100644 --- a/libs/pubsub/pubsub.go +++ b/libs/pubsub/pubsub.go @@ -231,34 +231,45 @@ func (s *Server) Unsubscribe(ctx context.Context, args UnsubscribeArgs) error { return err } var qs string + if args.Query != nil { qs = args.Query.String() } - s.mtx.RLock() - clientSubscriptions, ok := s.subscriptions[args.Subscriber] - if args.ID != "" { - qs, ok = clientSubscriptions[args.ID] - - if ok && args.Query == nil { - var err error - args.Query, err = query.New(qs) - if err != nil { - return err + clientSubscriptions, err := func() (map[string]string, error) { + s.mtx.RLock() + defer s.mtx.RUnlock() + + clientSubscriptions, ok := s.subscriptions[args.Subscriber] + if args.ID != "" { + qs, ok = clientSubscriptions[args.ID] + + if ok && args.Query == nil { + var err error + args.Query, err = query.New(qs) + if err != nil { + return nil, err + } } + } else if qs != "" { + args.ID, ok = clientSubscriptions[qs] } - } else if qs != "" { - args.ID, ok = clientSubscriptions[qs] - } - s.mtx.RUnlock() - if !ok { - return ErrSubscriptionNotFound + if !ok { + return nil, ErrSubscriptionNotFound + } + + return clientSubscriptions, nil + }() + + if err != nil { + return err } select { case s.cmds <- cmd{op: unsub, clientID: args.Subscriber, query: args.Query, subscription: &Subscription{id: args.ID}}: s.mtx.Lock() + defer s.mtx.Unlock() delete(clientSubscriptions, args.ID) delete(clientSubscriptions, qs) @@ -266,7 +277,6 @@ func (s *Server) Unsubscribe(ctx context.Context, args UnsubscribeArgs) error { if len(clientSubscriptions) == 0 { delete(s.subscriptions, args.Subscriber) } - s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -288,8 +298,10 @@ func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { select { case s.cmds <- cmd{op: unsub, clientID: clientID}: s.mtx.Lock() + defer s.mtx.Unlock() + delete(s.subscriptions, clientID) - s.mtx.Unlock() + return nil case <-ctx.Done(): return ctx.Err()