diff --git a/CHANGELOG.md b/CHANGELOG.md index 7d3e789a5..00ddd4b4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 0.8.1 (TBD) + +BUG FIXES: + + - [pubsub] fix unsubscribing + ## 0.8.0 (March 22, 2018) BREAKING: diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 28e008ca6..90f6e4ae6 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -28,6 +28,16 @@ const ( shutdown ) +var ( + // ErrSubscriptionNotFound is returned when a client tries to unsubscribe + // from not existing subscription. + ErrSubscriptionNotFound = errors.New("subscription not found") + + // ErrAlreadySubscribed is returned when a client tries to subscribe twice or + // more using the same query. + ErrAlreadySubscribed = errors.New("already subscribed") +) + type cmd struct { op operation query Query @@ -52,7 +62,7 @@ type Server struct { cmdsCap int mtx sync.RWMutex - subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{} + subscriptions map[string]map[string]Query // subscriber -> query (string) -> Query } // Option sets a parameter for the server. @@ -63,7 +73,7 @@ type Option func(*Server) // provided, the resulting server's queue is unbuffered. func NewServer(options ...Option) *Server { s := &Server{ - subscriptions: make(map[string]map[string]struct{}), + subscriptions: make(map[string]map[string]Query), } s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) @@ -106,16 +116,16 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou } s.mtx.RUnlock() if ok { - return errors.New("already subscribed") + return ErrAlreadySubscribed } select { case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: s.mtx.Lock() if _, ok = s.subscriptions[clientID]; !ok { - s.subscriptions[clientID] = make(map[string]struct{}) + s.subscriptions[clientID] = make(map[string]Query) } - s.subscriptions[clientID][query.String()] = struct{}{} + s.subscriptions[clientID][query.String()] = query s.mtx.Unlock() return nil case <-ctx.Done(): @@ -127,18 +137,20 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou // returned to the caller if the context is canceled or if subscription does // not exist. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { + var origQuery Query s.mtx.RLock() clientSubscriptions, ok := s.subscriptions[clientID] if ok { - _, ok = clientSubscriptions[query.String()] + origQuery, ok = clientSubscriptions[query.String()] } s.mtx.RUnlock() if !ok { - return errors.New("subscription not found") + return ErrSubscriptionNotFound } + // original query is used here because we're using pointers as map keys select { - case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: + case s.cmds <- cmd{op: unsub, clientID: clientID, query: origQuery}: s.mtx.Lock() delete(clientSubscriptions, query.String()) s.mtx.Unlock() @@ -155,7 +167,7 @@ func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { _, ok := s.subscriptions[clientID] s.mtx.RUnlock() if !ok { - return errors.New("subscription not found") + return ErrSubscriptionNotFound } select { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 84b6aa218..2af7cea46 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -101,9 +101,9 @@ func TestUnsubscribe(t *testing.T) { ctx := context.Background() ch := make(chan interface{}) - err := s.Subscribe(ctx, clientID, query.Empty{}, ch) + err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch) require.NoError(t, err) - err = s.Unsubscribe(ctx, clientID, query.Empty{}) + err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) err = s.Publish(ctx, "Nick Fury")