@ -28,6 +28,16 @@ const (
shutdown
)
var (
// ErrSubscriptionNotFound is returned when a client tries to unsubscribe
// from not existing subscription.
ErrSubscriptionNotFound = errors . New ( "subscription not found" )
// ErrAlreadySubscribed is returned when a client tries to subscribe twice or
// more using the same query.
ErrAlreadySubscribed = errors . New ( "already subscribed" )
)
type cmd struct {
op operation
query Query
@ -52,7 +62,7 @@ type Server struct {
cmdsCap int
mtx sync . RWMutex
subscriptions map [ string ] map [ string ] struct { } // subscriber -> query -> struct{}
subscriptions map [ string ] map [ string ] Query // subscriber -> query (string) -> Query
}
// Option sets a parameter for the server.
@ -63,7 +73,7 @@ type Option func(*Server)
// provided, the resulting server's queue is unbuffered.
func NewServer ( options ... Option ) * Server {
s := & Server {
subscriptions : make ( map [ string ] map [ string ] struct { } ) ,
subscriptions : make ( map [ string ] map [ string ] Query ) ,
}
s . BaseService = * cmn . NewBaseService ( nil , "PubSub" , s )
@ -106,16 +116,16 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou
}
s . mtx . RUnlock ( )
if ok {
return errors . New ( "already subscribed" )
return ErrAlreadySubscribed
}
select {
case s . cmds <- cmd { op : sub , clientID : clientID , query : query , ch : out } :
s . mtx . Lock ( )
if _ , ok = s . subscriptions [ clientID ] ; ! ok {
s . subscriptions [ clientID ] = make ( map [ string ] struct { } )
s . subscriptions [ clientID ] = make ( map [ string ] Query )
}
s . subscriptions [ clientID ] [ query . String ( ) ] = struct { } { }
s . subscriptions [ clientID ] [ query . String ( ) ] = query
s . mtx . Unlock ( )
return nil
case <- ctx . Done ( ) :
@ -127,18 +137,20 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou
// returned to the caller if the context is canceled or if subscription does
// not exist.
func ( s * Server ) Unsubscribe ( ctx context . Context , clientID string , query Query ) error {
var origQuery Query
s . mtx . RLock ( )
clientSubscriptions , ok := s . subscriptions [ clientID ]
if ok {
_ , ok = clientSubscriptions [ query . String ( ) ]
origQuery , ok = clientSubscriptions [ query . String ( ) ]
}
s . mtx . RUnlock ( )
if ! ok {
return errors . New ( "subscription not found" )
return ErrSubscriptionNotFound
}
// original query is used here because we're using pointers as map keys
select {
case s . cmds <- cmd { op : unsub , clientID : clientID , query : q uery} :
case s . cmds <- cmd { op : unsub , clientID : clientID , query : origQ uery} :
s . mtx . Lock ( )
delete ( clientSubscriptions , query . String ( ) )
s . mtx . Unlock ( )
@ -155,7 +167,7 @@ func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
_ , ok := s . subscriptions [ clientID ]
s . mtx . RUnlock ( )
if ! ok {
return errors . New ( "subscription not found" )
return ErrSubscriptionNotFound
}
select {