You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

344 lines
8.7 KiB

  1. // Package pubsub implements a pub-sub model with a single publisher (Server)
  2. // and multiple subscribers (clients).
  3. //
  // Multiple publishers are still possible, either by sharing a pointer to a
  // server or by giving the same channel to each publisher and publishing
  // messages from that channel (fan-in).
  7. //
  8. // Clients subscribe for messages, which could be of any type, using a query.
  // When a message is published, it is matched against every query. If there
  // is a match, the message is pushed to all clients subscribed to that query.
  // See the query subpackage for our implementation.
  12. package pubsub
  13. import (
  14. "context"
  15. "errors"
  16. "sync"
  17. cmn "github.com/tendermint/tmlibs/common"
  18. )
  // operation identifies the kind of request carried by a cmd.
  type operation int

  // Operations understood by the server loop. Values start at zero and
  // increase in declaration order.
  const (
  	sub      operation = iota // register a client's channel for a query
  	pub                       // deliver a message to matching subscribers
  	unsub                     // drop one subscription, or all of a client's when query is nil
  	shutdown                  // close every subscription and stop the loop
  )
  // ErrSubscriptionNotFound is returned when a client tries to unsubscribe
  // from a subscription that does not exist.
  var ErrSubscriptionNotFound = errors.New("subscription not found")

  // ErrAlreadySubscribed is returned when a client tries to subscribe twice
  // or more with the same query.
  var ErrAlreadySubscribed = errors.New("already subscribed")
  34. type cmd struct {
  35. op operation
  36. query Query
  37. ch chan<- interface{}
  38. clientID string
  39. msg interface{}
  40. tags TagMap
  41. }
  42. // Query defines an interface for a query to be used for subscribing.
  43. type Query interface {
  44. Matches(tags TagMap) bool
  45. String() string
  46. }
  47. // Server allows clients to subscribe/unsubscribe for messages, publishing
  48. // messages with or without tags, and manages internal state.
  49. type Server struct {
  50. cmn.BaseService
  51. cmds chan cmd
  52. cmdsCap int
  53. mtx sync.RWMutex
  54. subscriptions map[string]map[string]Query // subscriber -> query (string) -> Query
  55. }
  56. // Option sets a parameter for the server.
  57. type Option func(*Server)
  // TagMap is the set of tags attached to a published message. Subscribers'
  // queries are evaluated against it to decide which messages they receive.
  type TagMap interface {
  	// Get looks up key. It returns the associated value and true when the key
  	// is present, or the empty string and false otherwise.
  	Get(key string) (value string, ok bool)
  	// Len reports how many tags the set holds.
  	Len() int
  }

  // tagMap is the map-backed TagMap returned by NewTagMap.
  type tagMap map[string]string

  // Compile-time check that tagMap satisfies TagMap.
  var _ TagMap = tagMap(nil)

  // NewTagMap wraps data as a TagMap. The map is not copied, so callers should
  // not modify it after handing it over. A nil map yields an empty set.
  func NewTagMap(data map[string]string) TagMap {
  	var tm tagMap = data
  	return tm
  }

  // Get implements TagMap.
  func (ts tagMap) Get(key string) (value string, ok bool) {
  	if v, found := ts[key]; found {
  		return v, true
  	}
  	return "", false
  }

  // Len implements TagMap.
  func (ts tagMap) Len() int {
  	return len(ts)
  }
  83. // NewServer returns a new server. See the commentary on the Option functions
  84. // for a detailed description of how to configure buffering. If no options are
  85. // provided, the resulting server's queue is unbuffered.
  86. func NewServer(options ...Option) *Server {
  87. s := &Server{
  88. subscriptions: make(map[string]map[string]Query),
  89. }
  90. s.BaseService = *cmn.NewBaseService(nil, "PubSub", s)
  91. for _, option := range options {
  92. option(s)
  93. }
  94. // if BufferCapacity option was not set, the channel is unbuffered
  95. s.cmds = make(chan cmd, s.cmdsCap)
  96. return s
  97. }
  98. // BufferCapacity allows you to specify capacity for the internal server's
  99. // queue. Since the server, given Y subscribers, could only process X messages,
  100. // this option could be used to survive spikes (e.g. high amount of
  101. // transactions during peak hours).
  102. func BufferCapacity(cap int) Option {
  103. return func(s *Server) {
  104. if cap > 0 {
  105. s.cmdsCap = cap
  106. }
  107. }
  108. }
  109. // BufferCapacity returns capacity of the internal server's queue.
  110. func (s *Server) BufferCapacity() int {
  111. return s.cmdsCap
  112. }
  113. // Subscribe creates a subscription for the given client. It accepts a channel
  114. // on which messages matching the given query can be received. An error will be
  115. // returned to the caller if the context is canceled or if subscription already
  116. // exist for pair clientID and query.
  117. func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error {
  118. s.mtx.RLock()
  119. clientSubscriptions, ok := s.subscriptions[clientID]
  120. if ok {
  121. _, ok = clientSubscriptions[query.String()]
  122. }
  123. s.mtx.RUnlock()
  124. if ok {
  125. return ErrAlreadySubscribed
  126. }
  127. select {
  128. case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}:
  129. s.mtx.Lock()
  130. if _, ok = s.subscriptions[clientID]; !ok {
  131. s.subscriptions[clientID] = make(map[string]Query)
  132. }
  133. s.subscriptions[clientID][query.String()] = query
  134. s.mtx.Unlock()
  135. return nil
  136. case <-ctx.Done():
  137. return ctx.Err()
  138. }
  139. }
  140. // Unsubscribe removes the subscription on the given query. An error will be
  141. // returned to the caller if the context is canceled or if subscription does
  142. // not exist.
  143. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
  144. var origQuery Query
  145. s.mtx.RLock()
  146. clientSubscriptions, ok := s.subscriptions[clientID]
  147. if ok {
  148. origQuery, ok = clientSubscriptions[query.String()]
  149. }
  150. s.mtx.RUnlock()
  151. if !ok {
  152. return ErrSubscriptionNotFound
  153. }
  154. // original query is used here because we're using pointers as map keys
  155. select {
  156. case s.cmds <- cmd{op: unsub, clientID: clientID, query: origQuery}:
  157. s.mtx.Lock()
  158. delete(clientSubscriptions, query.String())
  159. s.mtx.Unlock()
  160. return nil
  161. case <-ctx.Done():
  162. return ctx.Err()
  163. }
  164. }
  165. // UnsubscribeAll removes all client subscriptions. An error will be returned
  166. // to the caller if the context is canceled or if subscription does not exist.
  167. func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
  168. s.mtx.RLock()
  169. _, ok := s.subscriptions[clientID]
  170. s.mtx.RUnlock()
  171. if !ok {
  172. return ErrSubscriptionNotFound
  173. }
  174. select {
  175. case s.cmds <- cmd{op: unsub, clientID: clientID}:
  176. s.mtx.Lock()
  177. delete(s.subscriptions, clientID)
  178. s.mtx.Unlock()
  179. return nil
  180. case <-ctx.Done():
  181. return ctx.Err()
  182. }
  183. }
  184. // Publish publishes the given message. An error will be returned to the caller
  185. // if the context is canceled.
  186. func (s *Server) Publish(ctx context.Context, msg interface{}) error {
  187. return s.PublishWithTags(ctx, msg, NewTagMap(make(map[string]string)))
  188. }
  189. // PublishWithTags publishes the given message with the set of tags. The set is
  190. // matched with clients queries. If there is a match, the message is sent to
  191. // the client.
  192. func (s *Server) PublishWithTags(ctx context.Context, msg interface{}, tags TagMap) error {
  193. select {
  194. case s.cmds <- cmd{op: pub, msg: msg, tags: tags}:
  195. return nil
  196. case <-ctx.Done():
  197. return ctx.Err()
  198. }
  199. }
  200. // OnStop implements Service.OnStop by shutting down the server.
  201. func (s *Server) OnStop() {
  202. s.cmds <- cmd{op: shutdown}
  203. }
  204. // NOTE: not goroutine safe
  205. type state struct {
  206. // query -> client -> ch
  207. queries map[Query]map[string]chan<- interface{}
  208. // client -> query -> struct{}
  209. clients map[string]map[Query]struct{}
  210. }
  211. // OnStart implements Service.OnStart by starting the server.
  212. func (s *Server) OnStart() error {
  213. go s.loop(state{
  214. queries: make(map[Query]map[string]chan<- interface{}),
  215. clients: make(map[string]map[Query]struct{}),
  216. })
  217. return nil
  218. }
  219. // OnReset implements Service.OnReset
  220. func (s *Server) OnReset() error {
  221. return nil
  222. }
  223. func (s *Server) loop(state state) {
  224. loop:
  225. for cmd := range s.cmds {
  226. switch cmd.op {
  227. case unsub:
  228. if cmd.query != nil {
  229. state.remove(cmd.clientID, cmd.query)
  230. } else {
  231. state.removeAll(cmd.clientID)
  232. }
  233. case shutdown:
  234. for clientID := range state.clients {
  235. state.removeAll(clientID)
  236. }
  237. break loop
  238. case sub:
  239. state.add(cmd.clientID, cmd.query, cmd.ch)
  240. case pub:
  241. state.send(cmd.msg, cmd.tags)
  242. }
  243. }
  244. }
  245. func (state *state) add(clientID string, q Query, ch chan<- interface{}) {
  246. // add query if needed
  247. if _, ok := state.queries[q]; !ok {
  248. state.queries[q] = make(map[string]chan<- interface{})
  249. }
  250. // create subscription
  251. state.queries[q][clientID] = ch
  252. // add client if needed
  253. if _, ok := state.clients[clientID]; !ok {
  254. state.clients[clientID] = make(map[Query]struct{})
  255. }
  256. state.clients[clientID][q] = struct{}{}
  257. }
  258. func (state *state) remove(clientID string, q Query) {
  259. clientToChannelMap, ok := state.queries[q]
  260. if !ok {
  261. return
  262. }
  263. ch, ok := clientToChannelMap[clientID]
  264. if ok {
  265. close(ch)
  266. delete(state.clients[clientID], q)
  267. // if it not subscribed to anything else, remove the client
  268. if len(state.clients[clientID]) == 0 {
  269. delete(state.clients, clientID)
  270. }
  271. delete(state.queries[q], clientID)
  272. }
  273. }
  274. func (state *state) removeAll(clientID string) {
  275. queryMap, ok := state.clients[clientID]
  276. if !ok {
  277. return
  278. }
  279. for q := range queryMap {
  280. ch := state.queries[q][clientID]
  281. close(ch)
  282. delete(state.queries[q], clientID)
  283. }
  284. delete(state.clients, clientID)
  285. }
  286. func (state *state) send(msg interface{}, tags TagMap) {
  287. for q, clientToChannelMap := range state.queries {
  288. if q.Matches(tags) {
  289. for _, ch := range clientToChannelMap {
  290. ch <- msg
  291. }
  292. }
  293. }
  294. }