diff --git a/libs/pubsub/pubsub.go b/libs/pubsub/pubsub.go index cb41f8b38..a39c8c731 100644 --- a/libs/pubsub/pubsub.go +++ b/libs/pubsub/pubsub.go @@ -101,7 +101,7 @@ type Server struct { cmdsCap int mtx sync.RWMutex - subscriptions map[string]map[string]Query // subscriber -> query (string) -> Query + subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct } // Option sets a parameter for the server. @@ -143,7 +143,7 @@ func (ts tagMap) Len() int { // provided, the resulting server's queue is unbuffered. func NewServer(options ...Option) *Server { s := &Server{ - subscriptions: make(map[string]map[string]Query), + subscriptions: make(map[string]map[string]struct{}), } s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) @@ -193,11 +193,9 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: s.mtx.Lock() if _, ok = s.subscriptions[clientID]; !ok { - s.subscriptions[clientID] = make(map[string]Query) + s.subscriptions[clientID] = make(map[string]struct{}) } - // preserve original query - // see Unsubscribe - s.subscriptions[clientID][query.String()] = query + s.subscriptions[clientID][query.String()] = struct{}{} s.mtx.Unlock() return nil case <-ctx.Done(): @@ -211,24 +209,23 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou // returned to the caller if the context is canceled or if subscription does // not exist. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { - var origQuery Query s.mtx.RLock() clientSubscriptions, ok := s.subscriptions[clientID] if ok { - origQuery, ok = clientSubscriptions[query.String()] + _, ok = clientSubscriptions[query.String()] } s.mtx.RUnlock() if !ok { return ErrSubscriptionNotFound } - // original query is used here because we're using pointers as map keys - // ? select { - case s.cmds <- cmd{op: unsub, clientID: clientID, query: origQuery}: + case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: s.mtx.Lock() - // if its the only query left, should we also delete the client? delete(clientSubscriptions, query.String()) + if len(clientSubscriptions) == 0 { + delete(s.subscriptions, clientID) + } s.mtx.Unlock() return nil case <-ctx.Done(): @@ -288,17 +285,27 @@ func (s *Server) OnStop() { // NOTE: not goroutine safe type state struct { - // query -> client -> ch - queries map[Query]map[string]chan<- interface{} - // client -> query -> struct{} - clients map[string]map[Query]struct{} + // query string -> client -> ch + queryToChanMap map[string]map[string]chan<- interface{} + // client -> query string -> struct{} + clientToQueryMap map[string]map[string]struct{} + // query string -> queryPlusRefCount + queries map[string]*queryPlusRefCount +} + +// queryPlusRefCount holds a pointer to a query and reference counter. When +// refCount is zero, query will be removed. +type queryPlusRefCount struct { + q Query + refCount int } // OnStart implements Service.OnStart by starting the server. func (s *Server) OnStart() error { go s.loop(state{ - queries: make(map[Query]map[string]chan<- interface{}), - clients: make(map[string]map[Query]struct{}), + queryToChanMap: make(map[string]map[string]chan<- interface{}), + clientToQueryMap: make(map[string]map[string]struct{}), + queries: make(map[string]*queryPlusRefCount), }) return nil } @@ -319,7 +326,7 @@ loop: state.removeAll(cmd.clientID) } case shutdown: - for clientID := range state.clients { + for clientID := range state.clientToQueryMap { state.removeAll(clientID) } break loop @@ -332,24 +339,34 @@ loop: } func (state *state) add(clientID string, q Query, ch chan<- interface{}) { + qStr := q.String() // initialize clientToChannelMap per query if needed - if _, ok := state.queries[q]; !ok { - state.queries[q] = make(map[string]chan<- interface{}) + if _, ok := state.queryToChanMap[qStr]; !ok { + state.queryToChanMap[qStr] = make(map[string]chan<- interface{}) } // create subscription - state.queries[q][clientID] = ch + state.queryToChanMap[qStr][clientID] = ch + + // initialize queries if needed + if _, ok := state.queries[qStr]; !ok { + state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0} + } + // increment reference counter + state.queries[qStr].refCount++ // add client if needed - if _, ok := state.clients[clientID]; !ok { - state.clients[clientID] = make(map[Query]struct{}) + if _, ok := state.clientToQueryMap[clientID]; !ok { + state.clientToQueryMap[clientID] = make(map[string]struct{}) } - state.clients[clientID][q] = struct{}{} + state.clientToQueryMap[clientID][qStr] = struct{}{} } func (state *state) remove(clientID string, q Query) { - clientToChannelMap, ok := state.queries[q] + qStr := q.String() + + clientToChannelMap, ok := state.queryToChanMap[qStr] if !ok { return } @@ -363,43 +380,58 @@ func (state *state) remove(clientID string, q Query) { // remove the query from client map. // if client is not subscribed to anything else, remove it. - delete(state.clients[clientID], q) - if len(state.clients[clientID]) == 0 { - delete(state.clients, clientID) + delete(state.clientToQueryMap[clientID], qStr) + if len(state.clientToQueryMap[clientID]) == 0 { + delete(state.clientToQueryMap, clientID) } // remove the client from query map. // if query has no other clients subscribed, remove it. - delete(state.queries[q], clientID) - if len(state.queries[q]) == 0 { - delete(state.queries, q) + delete(state.queryToChanMap[qStr], clientID) + if len(state.queryToChanMap[qStr]) == 0 { + delete(state.queryToChanMap, qStr) + } + + // decrease ref counter in queries + state.queries[qStr].refCount-- + // remove the query if nobody else is using it + if state.queries[qStr].refCount == 0 { + delete(state.queries, qStr) } } func (state *state) removeAll(clientID string) { - queryMap, ok := state.clients[clientID] + queryMap, ok := state.clientToQueryMap[clientID] if !ok { return } - for q := range queryMap { - ch := state.queries[q][clientID] + for qStr := range queryMap { + ch := state.queryToChanMap[qStr][clientID] close(ch) // remove the client from query map. // if query has no other clients subscribed, remove it. - delete(state.queries[q], clientID) - if len(state.queries[q]) == 0 { - delete(state.queries, q) + delete(state.queryToChanMap[qStr], clientID) + if len(state.queryToChanMap[qStr]) == 0 { + delete(state.queryToChanMap, qStr) + } + + // decrease ref counter in queries + state.queries[qStr].refCount-- + // remove the query if nobody else is using it + if state.queries[qStr].refCount == 0 { + delete(state.queries, qStr) } } // remove the client. - delete(state.clients, clientID) + delete(state.clientToQueryMap, clientID) } func (state *state) send(msg interface{}, tags TagMap) { - for q, clientToChannelMap := range state.queries { + for qStr, clientToChannelMap := range state.queryToChanMap { + q := state.queries[qStr].q if q.Matches(tags) { for _, ch := range clientToChannelMap { ch <- msg diff --git a/libs/pubsub/pubsub_test.go b/libs/pubsub/pubsub_test.go index 5e9931e40..bb660d9e0 100644 --- a/libs/pubsub/pubsub_test.go +++ b/libs/pubsub/pubsub_test.go @@ -115,6 +115,25 @@ func TestUnsubscribe(t *testing.T) { assert.False(t, ok) } +func TestClientUnsubscribesTwice(t *testing.T) { + s := pubsub.NewServer() + s.SetLogger(log.TestingLogger()) + s.Start() + defer s.Stop() + + ctx := context.Background() + ch := make(chan interface{}) + err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch) + require.NoError(t, err) + err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) + require.NoError(t, err) + + err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) + assert.Equal(t, pubsub.ErrSubscriptionNotFound, err) + err = s.UnsubscribeAll(ctx, clientID) + assert.Equal(t, pubsub.ErrSubscriptionNotFound, err) +} + func TestResubscribe(t *testing.T) { s := pubsub.NewServer() s.SetLogger(log.TestingLogger())