|
|
@ -101,7 +101,7 @@ type Server struct { |
|
|
|
cmdsCap int |
|
|
|
|
|
|
|
mtx sync.RWMutex |
|
|
|
subscriptions map[string]map[string]Query // subscriber -> query (string) -> Query
|
|
|
|
subscriptions map[string]map[string]struct{} // subscriber -> query (string) -> empty struct
|
|
|
|
} |
|
|
|
|
|
|
|
// Option sets a parameter for the server.
|
|
|
@ -143,7 +143,7 @@ func (ts tagMap) Len() int { |
|
|
|
// provided, the resulting server's queue is unbuffered.
|
|
|
|
func NewServer(options ...Option) *Server { |
|
|
|
s := &Server{ |
|
|
|
subscriptions: make(map[string]map[string]Query), |
|
|
|
subscriptions: make(map[string]map[string]struct{}), |
|
|
|
} |
|
|
|
s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) |
|
|
|
|
|
|
@ -193,11 +193,9 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou |
|
|
|
case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: |
|
|
|
s.mtx.Lock() |
|
|
|
if _, ok = s.subscriptions[clientID]; !ok { |
|
|
|
s.subscriptions[clientID] = make(map[string]Query) |
|
|
|
s.subscriptions[clientID] = make(map[string]struct{}) |
|
|
|
} |
|
|
|
// preserve original query
|
|
|
|
// see Unsubscribe
|
|
|
|
s.subscriptions[clientID][query.String()] = query |
|
|
|
s.subscriptions[clientID][query.String()] = struct{}{} |
|
|
|
s.mtx.Unlock() |
|
|
|
return nil |
|
|
|
case <-ctx.Done(): |
|
|
@ -211,24 +209,23 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou |
|
|
|
// returned to the caller if the context is canceled or if subscription does
|
|
|
|
// not exist.
|
|
|
|
func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { |
|
|
|
var origQuery Query |
|
|
|
s.mtx.RLock() |
|
|
|
clientSubscriptions, ok := s.subscriptions[clientID] |
|
|
|
if ok { |
|
|
|
origQuery, ok = clientSubscriptions[query.String()] |
|
|
|
_, ok = clientSubscriptions[query.String()] |
|
|
|
} |
|
|
|
s.mtx.RUnlock() |
|
|
|
if !ok { |
|
|
|
return ErrSubscriptionNotFound |
|
|
|
} |
|
|
|
|
|
|
|
// original query is used here because we're using pointers as map keys
|
|
|
|
// ?
|
|
|
|
select { |
|
|
|
case s.cmds <- cmd{op: unsub, clientID: clientID, query: origQuery}: |
|
|
|
case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: |
|
|
|
s.mtx.Lock() |
|
|
|
// if its the only query left, should we also delete the client?
|
|
|
|
delete(clientSubscriptions, query.String()) |
|
|
|
if len(clientSubscriptions) == 0 { |
|
|
|
delete(s.subscriptions, clientID) |
|
|
|
} |
|
|
|
s.mtx.Unlock() |
|
|
|
return nil |
|
|
|
case <-ctx.Done(): |
|
|
@ -288,17 +285,27 @@ func (s *Server) OnStop() { |
|
|
|
|
|
|
|
// NOTE: not goroutine safe
|
|
|
|
type state struct { |
|
|
|
// query -> client -> ch
|
|
|
|
queries map[Query]map[string]chan<- interface{} |
|
|
|
// client -> query -> struct{}
|
|
|
|
clients map[string]map[Query]struct{} |
|
|
|
// query string -> client -> ch
|
|
|
|
queryToChanMap map[string]map[string]chan<- interface{} |
|
|
|
// client -> query string -> struct{}
|
|
|
|
clientToQueryMap map[string]map[string]struct{} |
|
|
|
// query string -> queryPlusRefCount
|
|
|
|
queries map[string]*queryPlusRefCount |
|
|
|
} |
|
|
|
|
|
|
|
// queryPlusRefCount holds a pointer to a query and reference counter. When
|
|
|
|
// refCount is zero, query will be removed.
|
|
|
|
type queryPlusRefCount struct { |
|
|
|
q Query |
|
|
|
refCount int |
|
|
|
} |
|
|
|
|
|
|
|
// OnStart implements Service.OnStart by starting the server.
|
|
|
|
func (s *Server) OnStart() error { |
|
|
|
go s.loop(state{ |
|
|
|
queries: make(map[Query]map[string]chan<- interface{}), |
|
|
|
clients: make(map[string]map[Query]struct{}), |
|
|
|
queryToChanMap: make(map[string]map[string]chan<- interface{}), |
|
|
|
clientToQueryMap: make(map[string]map[string]struct{}), |
|
|
|
queries: make(map[string]*queryPlusRefCount), |
|
|
|
}) |
|
|
|
return nil |
|
|
|
} |
|
|
@ -319,7 +326,7 @@ loop: |
|
|
|
state.removeAll(cmd.clientID) |
|
|
|
} |
|
|
|
case shutdown: |
|
|
|
for clientID := range state.clients { |
|
|
|
for clientID := range state.clientToQueryMap { |
|
|
|
state.removeAll(clientID) |
|
|
|
} |
|
|
|
break loop |
|
|
@ -332,24 +339,34 @@ loop: |
|
|
|
} |
|
|
|
|
|
|
|
func (state *state) add(clientID string, q Query, ch chan<- interface{}) { |
|
|
|
qStr := q.String() |
|
|
|
|
|
|
|
// initialize clientToChannelMap per query if needed
|
|
|
|
if _, ok := state.queries[q]; !ok { |
|
|
|
state.queries[q] = make(map[string]chan<- interface{}) |
|
|
|
if _, ok := state.queryToChanMap[qStr]; !ok { |
|
|
|
state.queryToChanMap[qStr] = make(map[string]chan<- interface{}) |
|
|
|
} |
|
|
|
|
|
|
|
// create subscription
|
|
|
|
state.queries[q][clientID] = ch |
|
|
|
state.queryToChanMap[qStr][clientID] = ch |
|
|
|
|
|
|
|
// initialize queries if needed
|
|
|
|
if _, ok := state.queries[qStr]; !ok { |
|
|
|
state.queries[qStr] = &queryPlusRefCount{q: q, refCount: 0} |
|
|
|
} |
|
|
|
// increment reference counter
|
|
|
|
state.queries[qStr].refCount++ |
|
|
|
|
|
|
|
// add client if needed
|
|
|
|
if _, ok := state.clients[clientID]; !ok { |
|
|
|
state.clients[clientID] = make(map[Query]struct{}) |
|
|
|
if _, ok := state.clientToQueryMap[clientID]; !ok { |
|
|
|
state.clientToQueryMap[clientID] = make(map[string]struct{}) |
|
|
|
} |
|
|
|
state.clients[clientID][q] = struct{}{} |
|
|
|
state.clientToQueryMap[clientID][qStr] = struct{}{} |
|
|
|
} |
|
|
|
|
|
|
|
func (state *state) remove(clientID string, q Query) { |
|
|
|
clientToChannelMap, ok := state.queries[q] |
|
|
|
qStr := q.String() |
|
|
|
|
|
|
|
clientToChannelMap, ok := state.queryToChanMap[qStr] |
|
|
|
if !ok { |
|
|
|
return |
|
|
|
} |
|
|
@ -363,43 +380,58 @@ func (state *state) remove(clientID string, q Query) { |
|
|
|
|
|
|
|
// remove the query from client map.
|
|
|
|
// if client is not subscribed to anything else, remove it.
|
|
|
|
delete(state.clients[clientID], q) |
|
|
|
if len(state.clients[clientID]) == 0 { |
|
|
|
delete(state.clients, clientID) |
|
|
|
delete(state.clientToQueryMap[clientID], qStr) |
|
|
|
if len(state.clientToQueryMap[clientID]) == 0 { |
|
|
|
delete(state.clientToQueryMap, clientID) |
|
|
|
} |
|
|
|
|
|
|
|
// remove the client from query map.
|
|
|
|
// if query has no other clients subscribed, remove it.
|
|
|
|
delete(state.queries[q], clientID) |
|
|
|
if len(state.queries[q]) == 0 { |
|
|
|
delete(state.queries, q) |
|
|
|
delete(state.queryToChanMap[qStr], clientID) |
|
|
|
if len(state.queryToChanMap[qStr]) == 0 { |
|
|
|
delete(state.queryToChanMap, qStr) |
|
|
|
} |
|
|
|
|
|
|
|
// decrease ref counter in queries
|
|
|
|
state.queries[qStr].refCount-- |
|
|
|
// remove the query if nobody else is using it
|
|
|
|
if state.queries[qStr].refCount == 0 { |
|
|
|
delete(state.queries, qStr) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (state *state) removeAll(clientID string) { |
|
|
|
queryMap, ok := state.clients[clientID] |
|
|
|
queryMap, ok := state.clientToQueryMap[clientID] |
|
|
|
if !ok { |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
for q := range queryMap { |
|
|
|
ch := state.queries[q][clientID] |
|
|
|
for qStr := range queryMap { |
|
|
|
ch := state.queryToChanMap[qStr][clientID] |
|
|
|
close(ch) |
|
|
|
|
|
|
|
// remove the client from query map.
|
|
|
|
// if query has no other clients subscribed, remove it.
|
|
|
|
delete(state.queries[q], clientID) |
|
|
|
if len(state.queries[q]) == 0 { |
|
|
|
delete(state.queries, q) |
|
|
|
delete(state.queryToChanMap[qStr], clientID) |
|
|
|
if len(state.queryToChanMap[qStr]) == 0 { |
|
|
|
delete(state.queryToChanMap, qStr) |
|
|
|
} |
|
|
|
|
|
|
|
// decrease ref counter in queries
|
|
|
|
state.queries[qStr].refCount-- |
|
|
|
// remove the query if nobody else is using it
|
|
|
|
if state.queries[qStr].refCount == 0 { |
|
|
|
delete(state.queries, qStr) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// remove the client.
|
|
|
|
delete(state.clients, clientID) |
|
|
|
delete(state.clientToQueryMap, clientID) |
|
|
|
} |
|
|
|
|
|
|
|
func (state *state) send(msg interface{}, tags TagMap) { |
|
|
|
for q, clientToChannelMap := range state.queries { |
|
|
|
for qStr, clientToChannelMap := range state.queryToChanMap { |
|
|
|
q := state.queries[qStr].q |
|
|
|
if q.Matches(tags) { |
|
|
|
for _, ch := range clientToChannelMap { |
|
|
|
ch <- msg |
|
|
|