diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 27f15cbeb..54a4b8aed 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -13,6 +13,8 @@ package pubsub import ( "context" + "errors" + "sync" cmn "github.com/tendermint/tmlibs/common" ) @@ -48,6 +50,9 @@ type Server struct { cmds chan cmd cmdsCap int + + mtx sync.RWMutex + subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{} } // Option sets a parameter for the server. @@ -57,7 +62,9 @@ type Option func(*Server) // for a detailed description of how to configure buffering. If no options are // provided, the resulting server's queue is unbuffered. func NewServer(options ...Option) *Server { - s := &Server{} + s := &Server{ + subscriptions: make(map[string]map[string]struct{}), + } s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) for _, option := range options { @@ -83,17 +90,33 @@ func BufferCapacity(cap int) Option { } // BufferCapacity returns capacity of the internal server's queue. -func (s Server) BufferCapacity() int { +func (s *Server) BufferCapacity() int { return s.cmdsCap } // Subscribe creates a subscription for the given client. It accepts a channel -// on which messages matching the given query can be received. If the -// subscription already exists, the old channel will be closed. An error will -// be returned to the caller if the context is canceled. +// on which messages matching the given query can be received. An error will be +// returned to the caller if the context is canceled or if subscription already +// exist for pair clientID and query. func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error { + s.mtx.RLock() + clientSubscriptions, ok := s.subscriptions[clientID] + if ok { + _, ok = clientSubscriptions[query.String()] + } + s.mtx.RUnlock() + if ok { + return errors.New("already subscribed") + } + select { case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: + s.mtx.Lock() + if _, ok = s.subscriptions[clientID]; !ok { + s.subscriptions[clientID] = make(map[string]struct{}) + } + s.subscriptions[clientID][query.String()] = struct{}{} + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -101,10 +124,24 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou } // Unsubscribe removes the subscription on the given query. An error will be -// returned to the caller if the context is canceled. +// returned to the caller if the context is canceled or if subscription does +// not exist. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { + s.mtx.RLock() + clientSubscriptions, ok := s.subscriptions[clientID] + if ok { + _, ok = clientSubscriptions[query.String()] + } + s.mtx.RUnlock() + if !ok { + return errors.New("subscription not found") + } + select { case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: + s.mtx.Lock() + delete(clientSubscriptions, query.String()) + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -112,10 +149,20 @@ func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) } // UnsubscribeAll removes all client subscriptions. An error will be returned -// to the caller if the context is canceled. +// to the caller if the context is canceled or if subscription does not exist. func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { + s.mtx.RLock() + _, ok := s.subscriptions[clientID] + s.mtx.RUnlock() + if !ok { + return errors.New("subscription not found") + } + select { case s.cmds <- cmd{op: unsub, clientID: clientID}: + s.mtx.Lock() + delete(s.subscriptions, clientID) + s.mtx.Unlock() return nil case <-ctx.Done(): return ctx.Err() @@ -187,13 +234,8 @@ loop: func (state *state) add(clientID string, q Query, ch chan<- interface{}) { // add query if needed - if clientToChannelMap, ok := state.queries[q]; !ok { + if _, ok := state.queries[q]; !ok { state.queries[q] = make(map[string]chan<- interface{}) - } else { - // check if already subscribed - if oldCh, ok := clientToChannelMap[clientID]; ok { - close(oldCh) - } } // create subscription diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 7bf7b41f7..84b6aa218 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -86,14 +86,11 @@ func TestClientSubscribesTwice(t *testing.T) { ch2 := make(chan interface{}, 1) err = s.Subscribe(ctx, clientID, q, ch2) - require.NoError(t, err) - - _, ok := <-ch1 - assert.False(t, ok) + require.Error(t, err) err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"}) require.NoError(t, err) - assertReceive(t, "Spider-Man", ch2) + assertReceive(t, "Spider-Man", ch1) } func TestUnsubscribe(t *testing.T) { @@ -117,6 +114,27 @@ func TestUnsubscribe(t *testing.T) { assert.False(t, ok) } +func TestResubscribe(t *testing.T) { + s := pubsub.NewServer() + s.SetLogger(log.TestingLogger()) + s.Start() + defer s.Stop() + + ctx := context.Background() + ch := make(chan interface{}) + err := s.Subscribe(ctx, clientID, query.Empty{}, ch) + require.NoError(t, err) + err = s.Unsubscribe(ctx, clientID, query.Empty{}) + require.NoError(t, err) + ch = make(chan interface{}) + err = s.Subscribe(ctx, clientID, query.Empty{}, ch) + require.NoError(t, err) + + err = s.Publish(ctx, "Cable") + require.NoError(t, err) + assertReceive(t, "Cable", ch) +} + func TestUnsubscribeAll(t *testing.T) { s := pubsub.NewServer() s.SetLogger(log.TestingLogger()) @@ -125,9 +143,9 @@ func TestUnsubscribeAll(t *testing.T) { ctx := context.Background() ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1) - err := s.Subscribe(ctx, clientID, query.Empty{}, ch1) + err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch1) require.NoError(t, err) - err = s.Subscribe(ctx, clientID, query.Empty{}, ch2) + err = s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'"), ch2) require.NoError(t, err) err = s.UnsubscribeAll(ctx, clientID)