Browse Source

Merge pull request #118 from tendermint/develop

v0.6.0
pull/1842/head
Ethan Buchman 7 years ago
committed by GitHub
parent
commit
91b4b534ad
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 740 additions and 273 deletions
  1. +13
    -0
      CHANGELOG.md
  2. +7
    -20
      cli/setup.go
  3. +5
    -11
      cli/setup_test.go
  4. +152
    -94
      clist/clist.go
  5. +3
    -4
      common/bit_array.go
  6. +100
    -21
      common/random.go
  7. +120
    -0
      common/random_test.go
  8. +184
    -42
      common/repeat_timer.go
  9. +68
    -54
      common/repeat_timer_test.go
  10. +56
    -13
      pubsub/pubsub.go
  11. +25
    -7
      pubsub/pubsub_test.go
  12. +6
    -6
      test.sh
  13. +1
    -1
      version/version.go

+ 13
- 0
CHANGELOG.md View File

@ -1,5 +1,18 @@
# Changelog # Changelog
## 0.6.0 (December 29, 2017)
BREAKING:
- [cli] remove --root
- [pubsub] add String() method to Query interface
IMPROVEMENTS:
- [common] use a thread-safe and well seeded non-crypto rng
BUG FIXES
- [clist] fix misuse of wait group
- [common] introduce Ticker interface and logicalTicker for better testing of timers
## 0.5.0 (December 5, 2017) ## 0.5.0 (December 5, 2017)
BREAKING: BREAKING:


+ 7
- 20
cli/setup.go View File

@ -14,7 +14,6 @@ import (
) )
const ( const (
RootFlag = "root"
HomeFlag = "home" HomeFlag = "home"
TraceFlag = "trace" TraceFlag = "trace"
OutputFlag = "output" OutputFlag = "output"
@ -28,14 +27,9 @@ type Executable interface {
} }
// PrepareBaseCmd is meant for tendermint and other servers // PrepareBaseCmd is meant for tendermint and other servers
func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor {
func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor {
cobra.OnInitialize(func() { initEnv(envPrefix) }) cobra.OnInitialize(func() { initEnv(envPrefix) })
cmd.PersistentFlags().StringP(RootFlag, "r", defautRoot, "DEPRECATED. Use --home")
// -h is already reserved for --help as part of the cobra framework
// do you want to try something else??
// also, default must be empty, so we can detect this unset and fall back
// to --root / TM_ROOT / TMROOT
cmd.PersistentFlags().String(HomeFlag, "", "root directory for config and data")
cmd.PersistentFlags().StringP(HomeFlag, "", defaultHome, "directory for config and data")
cmd.PersistentFlags().Bool(TraceFlag, false, "print out full stack trace on errors") cmd.PersistentFlags().Bool(TraceFlag, false, "print out full stack trace on errors")
cmd.PersistentPreRunE = concatCobraCmdFuncs(bindFlagsLoadViper, cmd.PersistentPreRunE) cmd.PersistentPreRunE = concatCobraCmdFuncs(bindFlagsLoadViper, cmd.PersistentPreRunE)
return Executor{cmd, os.Exit} return Executor{cmd, os.Exit}
@ -45,11 +39,11 @@ func PrepareBaseCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor {
// //
// This adds --encoding (hex, btc, base64) and --output (text, json) to // This adds --encoding (hex, btc, base64) and --output (text, json) to
// the command. These only really make sense in interactive commands. // the command. These only really make sense in interactive commands.
func PrepareMainCmd(cmd *cobra.Command, envPrefix, defautRoot string) Executor {
func PrepareMainCmd(cmd *cobra.Command, envPrefix, defaultHome string) Executor {
cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)") cmd.PersistentFlags().StringP(EncodingFlag, "e", "hex", "Binary encoding (hex|b64|btc)")
cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)") cmd.PersistentFlags().StringP(OutputFlag, "o", "text", "Output format (text|json)")
cmd.PersistentPreRunE = concatCobraCmdFuncs(setEncoding, validateOutput, cmd.PersistentPreRunE) cmd.PersistentPreRunE = concatCobraCmdFuncs(setEncoding, validateOutput, cmd.PersistentPreRunE)
return PrepareBaseCmd(cmd, envPrefix, defautRoot)
return PrepareBaseCmd(cmd, envPrefix, defaultHome)
} }
// initEnv sets to use ENV variables if set. // initEnv sets to use ENV variables if set.
@ -136,17 +130,10 @@ func bindFlagsLoadViper(cmd *cobra.Command, args []string) error {
return err return err
} }
// rootDir is command line flag, env variable, or default $HOME/.tlc
// NOTE: we support both --root and --home for now, but eventually only --home
// Also ensure we set the correct rootDir under HomeFlag so we dont need to
// repeat this logic elsewhere.
rootDir := viper.GetString(HomeFlag)
if rootDir == "" {
rootDir = viper.GetString(RootFlag)
viper.Set(HomeFlag, rootDir)
}
homeDir := viper.GetString(HomeFlag)
viper.Set(HomeFlag, homeDir)
viper.SetConfigName("config") // name of config file (without extension) viper.SetConfigName("config") // name of config file (without extension)
viper.AddConfigPath(rootDir) // search root directory
viper.AddConfigPath(homeDir) // search root directory
// If a config file is found, read it in. // If a config file is found, read it in.
if err := viper.ReadInConfig(); err == nil { if err := viper.ReadInConfig(); err == nil {


+ 5
- 11
cli/setup_test.go View File

@ -57,12 +57,9 @@ func TestSetupEnv(t *testing.T) {
func TestSetupConfig(t *testing.T) { func TestSetupConfig(t *testing.T) {
// we pre-create two config files we can refer to in the rest of // we pre-create two config files we can refer to in the rest of
// the test cases. // the test cases.
cval1, cval2 := "fubble", "wubble"
cval1 := "fubble"
conf1, err := WriteDemoConfig(map[string]string{"boo": cval1}) conf1, err := WriteDemoConfig(map[string]string{"boo": cval1})
require.Nil(t, err) require.Nil(t, err)
// make sure it handles dashed-words in the config, and ignores random info
conf2, err := WriteDemoConfig(map[string]string{"boo": cval2, "foo": "bar", "two-words": "WORD"})
require.Nil(t, err)
cases := []struct { cases := []struct {
args []string args []string
@ -74,16 +71,13 @@ func TestSetupConfig(t *testing.T) {
// setting on the command line // setting on the command line
{[]string{"--boo", "haha"}, nil, "haha", ""}, {[]string{"--boo", "haha"}, nil, "haha", ""},
{[]string{"--two-words", "rocks"}, nil, "", "rocks"}, {[]string{"--two-words", "rocks"}, nil, "", "rocks"},
{[]string{"--root", conf1}, nil, cval1, ""},
{[]string{"--home", conf1}, nil, cval1, ""},
// test both variants of the prefix // test both variants of the prefix
{nil, map[string]string{"RD_BOO": "bang"}, "bang", ""}, {nil, map[string]string{"RD_BOO": "bang"}, "bang", ""},
{nil, map[string]string{"RD_TWO_WORDS": "fly"}, "", "fly"}, {nil, map[string]string{"RD_TWO_WORDS": "fly"}, "", "fly"},
{nil, map[string]string{"RDTWO_WORDS": "fly"}, "", "fly"}, {nil, map[string]string{"RDTWO_WORDS": "fly"}, "", "fly"},
{nil, map[string]string{"RD_ROOT": conf1}, cval1, ""},
{nil, map[string]string{"RDROOT": conf2}, cval2, "WORD"},
{nil, map[string]string{"RD_HOME": conf1}, cval1, ""},
{nil, map[string]string{"RDHOME": conf1}, cval1, ""}, {nil, map[string]string{"RDHOME": conf1}, cval1, ""},
// and when both are set??? HOME wins every time!
{[]string{"--root", conf1}, map[string]string{"RDHOME": conf2}, cval2, "WORD"},
} }
for idx, tc := range cases { for idx, tc := range cases {
@ -156,10 +150,10 @@ func TestSetupUnmarshal(t *testing.T) {
{nil, nil, c("", 0)}, {nil, nil, c("", 0)},
// setting on the command line // setting on the command line
{[]string{"--name", "haha"}, nil, c("haha", 0)}, {[]string{"--name", "haha"}, nil, c("haha", 0)},
{[]string{"--root", conf1}, nil, c(cval1, 0)},
{[]string{"--home", conf1}, nil, c(cval1, 0)},
// test both variants of the prefix // test both variants of the prefix
{nil, map[string]string{"MR_AGE": "56"}, c("", 56)}, {nil, map[string]string{"MR_AGE": "56"}, c("", 56)},
{nil, map[string]string{"MR_ROOT": conf1}, c(cval1, 0)},
{nil, map[string]string{"MR_HOME": conf1}, c(cval1, 0)},
{[]string{"--age", "17"}, map[string]string{"MRHOME": conf2}, c(cval2, 17)}, {[]string{"--age", "17"}, map[string]string{"MRHOME": conf2}, c(cval2, 17)},
} }


+ 152
- 94
clist/clist.go View File

@ -1,46 +1,68 @@
package clist package clist
/* /*
The purpose of CList is to provide a goroutine-safe linked-list. The purpose of CList is to provide a goroutine-safe linked-list.
This list can be traversed concurrently by any number of goroutines. This list can be traversed concurrently by any number of goroutines.
However, removed CElements cannot be added back. However, removed CElements cannot be added back.
NOTE: Not all methods of container/list are (yet) implemented. NOTE: Not all methods of container/list are (yet) implemented.
NOTE: Removed elements need to DetachPrev or DetachNext consistently NOTE: Removed elements need to DetachPrev or DetachNext consistently
to ensure garbage collection of removed elements. to ensure garbage collection of removed elements.
*/ */
import ( import (
"sync" "sync"
"sync/atomic"
"unsafe"
) )
// CElement is an element of a linked-list
// Traversal from a CElement are goroutine-safe.
/*
CElement is an element of a linked-list
Traversal from a CElement is goroutine-safe.
We can't avoid using WaitGroups or for-loops given the documentation
spec without re-implementing the primitives that already exist in
golang/sync. Notice that WaitGroup allows many go-routines to be
simultaneously released, which is what we want. Mutex doesn't do
this. RWMutex does this, but it's clumsy to use in the way that a
WaitGroup would be used -- and we'd end up having two RWMutex's for
prev/next each, which is doubly confusing.
sync.Cond would be sort-of useful, but we don't need a write-lock in
the for-loop. Use sync.Cond when you need serial access to the
"condition". In our case our condition is if `next != nil || removed`,
and there's no reason to serialize that condition for goroutines
waiting on NextWait() (since it's just a read operation).
*/
type CElement struct { type CElement struct {
prev unsafe.Pointer
mtx sync.RWMutex
prev *CElement
prevWg *sync.WaitGroup prevWg *sync.WaitGroup
next unsafe.Pointer
next *CElement
nextWg *sync.WaitGroup nextWg *sync.WaitGroup
removed uint32
Value interface{}
removed bool
Value interface{} // immutable
} }
// Blocking implementation of Next(). // Blocking implementation of Next().
// May return nil iff CElement was tail and got removed. // May return nil iff CElement was tail and got removed.
func (e *CElement) NextWait() *CElement { func (e *CElement) NextWait() *CElement {
for { for {
e.nextWg.Wait()
next := e.Next()
if next == nil {
if e.Removed() {
return nil
} else {
continue
}
} else {
e.mtx.RLock()
next := e.next
nextWg := e.nextWg
removed := e.removed
e.mtx.RUnlock()
if next != nil || removed {
return next return next
} }
nextWg.Wait()
// e.next doesn't necessarily exist here.
// That's why we need to continue a for-loop.
} }
} }
@ -48,82 +70,113 @@ func (e *CElement) NextWait() *CElement {
// May return nil iff CElement was head and got removed. // May return nil iff CElement was head and got removed.
func (e *CElement) PrevWait() *CElement { func (e *CElement) PrevWait() *CElement {
for { for {
e.prevWg.Wait()
prev := e.Prev()
if prev == nil {
if e.Removed() {
return nil
} else {
continue
}
} else {
e.mtx.RLock()
prev := e.prev
prevWg := e.prevWg
removed := e.removed
e.mtx.RUnlock()
if prev != nil || removed {
return prev return prev
} }
prevWg.Wait()
} }
} }
// Nonblocking, may return nil if at the end. // Nonblocking, may return nil if at the end.
func (e *CElement) Next() *CElement { func (e *CElement) Next() *CElement {
return (*CElement)(atomic.LoadPointer(&e.next))
e.mtx.RLock()
defer e.mtx.RUnlock()
return e.next
} }
// Nonblocking, may return nil if at the end. // Nonblocking, may return nil if at the end.
func (e *CElement) Prev() *CElement { func (e *CElement) Prev() *CElement {
return (*CElement)(atomic.LoadPointer(&e.prev))
e.mtx.RLock()
defer e.mtx.RUnlock()
return e.prev
} }
func (e *CElement) Removed() bool { func (e *CElement) Removed() bool {
return atomic.LoadUint32(&(e.removed)) > 0
e.mtx.RLock()
defer e.mtx.RUnlock()
return e.removed
} }
func (e *CElement) DetachNext() { func (e *CElement) DetachNext() {
if !e.Removed() { if !e.Removed() {
panic("DetachNext() must be called after Remove(e)") panic("DetachNext() must be called after Remove(e)")
} }
atomic.StorePointer(&e.next, nil)
e.mtx.Lock()
defer e.mtx.Unlock()
e.next = nil
} }
func (e *CElement) DetachPrev() { func (e *CElement) DetachPrev() {
if !e.Removed() { if !e.Removed() {
panic("DetachPrev() must be called after Remove(e)") panic("DetachPrev() must be called after Remove(e)")
} }
atomic.StorePointer(&e.prev, nil)
e.mtx.Lock()
defer e.mtx.Unlock()
e.prev = nil
} }
func (e *CElement) setNextAtomic(next *CElement) {
for {
oldNext := atomic.LoadPointer(&e.next)
if !atomic.CompareAndSwapPointer(&(e.next), oldNext, unsafe.Pointer(next)) {
continue
}
if next == nil && oldNext != nil { // We for-loop in NextWait() so race is ok
e.nextWg.Add(1)
}
if next != nil && oldNext == nil {
e.nextWg.Done()
}
return
// NOTE: This function needs to be safe for
// concurrent goroutines waiting on nextWg.
func (e *CElement) SetNext(newNext *CElement) {
e.mtx.Lock()
defer e.mtx.Unlock()
oldNext := e.next
e.next = newNext
if oldNext != nil && newNext == nil {
// See https://golang.org/pkg/sync/:
//
// If a WaitGroup is reused to wait for several independent sets of
// events, new Add calls must happen after all previous Wait calls have
// returned.
e.nextWg = waitGroup1() // WaitGroups are difficult to re-use.
}
if oldNext == nil && newNext != nil {
e.nextWg.Done()
} }
} }
func (e *CElement) setPrevAtomic(prev *CElement) {
for {
oldPrev := atomic.LoadPointer(&e.prev)
if !atomic.CompareAndSwapPointer(&(e.prev), oldPrev, unsafe.Pointer(prev)) {
continue
}
if prev == nil && oldPrev != nil { // We for-loop in PrevWait() so race is ok
e.prevWg.Add(1)
}
if prev != nil && oldPrev == nil {
e.prevWg.Done()
}
return
// NOTE: This function needs to be safe for
// concurrent goroutines waiting on prevWg
func (e *CElement) SetPrev(newPrev *CElement) {
e.mtx.Lock()
defer e.mtx.Unlock()
oldPrev := e.prev
e.prev = newPrev
if oldPrev != nil && newPrev == nil {
e.prevWg = waitGroup1() // WaitGroups are difficult to re-use.
}
if oldPrev == nil && newPrev != nil {
e.prevWg.Done()
} }
} }
func (e *CElement) setRemovedAtomic() {
atomic.StoreUint32(&(e.removed), 1)
func (e *CElement) SetRemoved() {
e.mtx.Lock()
defer e.mtx.Unlock()
e.removed = true
// This wakes up anyone waiting in either direction.
if e.prev == nil {
e.prevWg.Done()
}
if e.next == nil {
e.nextWg.Done()
}
} }
//-------------------------------------------------------------------------------- //--------------------------------------------------------------------------------
@ -132,7 +185,7 @@ func (e *CElement) setRemovedAtomic() {
// The zero value for CList is an empty list ready to use. // The zero value for CList is an empty list ready to use.
// Operations are goroutine-safe. // Operations are goroutine-safe.
type CList struct { type CList struct {
mtx sync.Mutex
mtx sync.RWMutex
wg *sync.WaitGroup wg *sync.WaitGroup
head *CElement // first element head *CElement // first element
tail *CElement // last element tail *CElement // last element
@ -142,6 +195,7 @@ type CList struct {
func (l *CList) Init() *CList { func (l *CList) Init() *CList {
l.mtx.Lock() l.mtx.Lock()
defer l.mtx.Unlock() defer l.mtx.Unlock()
l.wg = waitGroup1() l.wg = waitGroup1()
l.head = nil l.head = nil
l.tail = nil l.tail = nil
@ -152,48 +206,55 @@ func (l *CList) Init() *CList {
func New() *CList { return new(CList).Init() } func New() *CList { return new(CList).Init() }
func (l *CList) Len() int { func (l *CList) Len() int {
l.mtx.Lock()
defer l.mtx.Unlock()
l.mtx.RLock()
defer l.mtx.RUnlock()
return l.len return l.len
} }
func (l *CList) Front() *CElement { func (l *CList) Front() *CElement {
l.mtx.Lock()
defer l.mtx.Unlock()
l.mtx.RLock()
defer l.mtx.RUnlock()
return l.head return l.head
} }
func (l *CList) FrontWait() *CElement { func (l *CList) FrontWait() *CElement {
// Loop until the head is non-nil else wait and try again
for { for {
l.mtx.Lock()
l.mtx.RLock()
head := l.head head := l.head
wg := l.wg wg := l.wg
l.mtx.Unlock()
if head == nil {
wg.Wait()
} else {
l.mtx.RUnlock()
if head != nil {
return head return head
} }
wg.Wait()
// NOTE: If you think l.head exists here, think harder.
} }
} }
func (l *CList) Back() *CElement { func (l *CList) Back() *CElement {
l.mtx.Lock()
defer l.mtx.Unlock()
l.mtx.RLock()
defer l.mtx.RUnlock()
return l.tail return l.tail
} }
func (l *CList) BackWait() *CElement { func (l *CList) BackWait() *CElement {
for { for {
l.mtx.Lock()
l.mtx.RLock()
tail := l.tail tail := l.tail
wg := l.wg wg := l.wg
l.mtx.Unlock()
if tail == nil {
wg.Wait()
} else {
l.mtx.RUnlock()
if tail != nil {
return tail return tail
} }
wg.Wait()
// l.tail doesn't necessarily exist here.
// That's why we need to continue a for-loop.
} }
} }
@ -203,11 +264,12 @@ func (l *CList) PushBack(v interface{}) *CElement {
// Construct a new element // Construct a new element
e := &CElement{ e := &CElement{
prev: nil,
prevWg: waitGroup1(),
next: nil,
nextWg: waitGroup1(),
Value: v,
prev: nil,
prevWg: waitGroup1(),
next: nil,
nextWg: waitGroup1(),
removed: false,
Value: v,
} }
// Release waiters on FrontWait/BackWait maybe // Release waiters on FrontWait/BackWait maybe
@ -221,9 +283,9 @@ func (l *CList) PushBack(v interface{}) *CElement {
l.head = e l.head = e
l.tail = e l.tail = e
} else { } else {
l.tail.setNextAtomic(e)
e.setPrevAtomic(l.tail)
l.tail = e
e.SetPrev(l.tail) // We must init e first.
l.tail.SetNext(e) // This will make e accessible.
l.tail = e // Update the list.
} }
return e return e
@ -250,30 +312,26 @@ func (l *CList) Remove(e *CElement) interface{} {
// If we're removing the only item, make CList FrontWait/BackWait wait. // If we're removing the only item, make CList FrontWait/BackWait wait.
if l.len == 1 { if l.len == 1 {
l.wg.Add(1)
l.wg = waitGroup1() // WaitGroups are difficult to re-use.
} }
// Update l.len
l.len -= 1 l.len -= 1
// Connect next/prev and set head/tail // Connect next/prev and set head/tail
if prev == nil { if prev == nil {
l.head = next l.head = next
} else { } else {
prev.setNextAtomic(next)
prev.SetNext(next)
} }
if next == nil { if next == nil {
l.tail = prev l.tail = prev
} else { } else {
next.setPrevAtomic(prev)
next.SetPrev(prev)
} }
// Set .Done() on e, otherwise waiters will wait forever. // Set .Done() on e, otherwise waiters will wait forever.
e.setRemovedAtomic()
if prev == nil {
e.prevWg.Done()
}
if next == nil {
e.nextWg.Done()
}
e.SetRemoved()
return e.Value return e.Value
} }


+ 3
- 4
common/bit_array.go View File

@ -3,7 +3,6 @@ package common
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"math/rand"
"strings" "strings"
"sync" "sync"
) )
@ -212,12 +211,12 @@ func (bA *BitArray) PickRandom() (int, bool) {
if length == 0 { if length == 0 {
return 0, false return 0, false
} }
randElemStart := rand.Intn(length)
randElemStart := RandIntn(length)
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
elemIdx := ((i + randElemStart) % length) elemIdx := ((i + randElemStart) % length)
if elemIdx < length-1 { if elemIdx < length-1 {
if bA.Elems[elemIdx] > 0 { if bA.Elems[elemIdx] > 0 {
randBitStart := rand.Intn(64)
randBitStart := RandIntn(64)
for j := 0; j < 64; j++ { for j := 0; j < 64; j++ {
bitIdx := ((j + randBitStart) % 64) bitIdx := ((j + randBitStart) % 64)
if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 {
@ -232,7 +231,7 @@ func (bA *BitArray) PickRandom() (int, bool) {
if elemBits == 0 { if elemBits == 0 {
elemBits = 64 elemBits = 64
} }
randBitStart := rand.Intn(elemBits)
randBitStart := RandIntn(elemBits)
for j := 0; j < elemBits; j++ { for j := 0; j < elemBits; j++ {
bitIdx := ((j + randBitStart) % elemBits) bitIdx := ((j + randBitStart) % elemBits)
if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 { if (bA.Elems[elemIdx] & (uint64(1) << uint(bitIdx))) > 0 {


+ 100
- 21
common/random.go View File

@ -2,7 +2,8 @@ package common
import ( import (
crand "crypto/rand" crand "crypto/rand"
"math/rand"
mrand "math/rand"
"sync"
"time" "time"
) )
@ -10,22 +11,36 @@ const (
strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters strChars = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" // 62 characters
) )
func init() {
// pseudo random number generator.
// seeded with OS randomness (crand)
var prng struct {
sync.Mutex
*mrand.Rand
}
func reset() {
b := cRandBytes(8) b := cRandBytes(8)
var seed uint64 var seed uint64
for i := 0; i < 8; i++ { for i := 0; i < 8; i++ {
seed |= uint64(b[i]) seed |= uint64(b[i])
seed <<= 8 seed <<= 8
} }
rand.Seed(int64(seed))
prng.Lock()
prng.Rand = mrand.New(mrand.NewSource(int64(seed)))
prng.Unlock()
}
func init() {
reset()
} }
// Constructs an alphanumeric string of given length. // Constructs an alphanumeric string of given length.
// It is not safe for cryptographic usage.
func RandStr(length int) string { func RandStr(length int) string {
chars := []byte{} chars := []byte{}
MAIN_LOOP: MAIN_LOOP:
for { for {
val := rand.Int63()
val := RandInt63()
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
v := int(val & 0x3f) // rightmost 6 bits v := int(val & 0x3f) // rightmost 6 bits
if v >= 62 { // only 62 characters in strChars if v >= 62 { // only 62 characters in strChars
@ -44,87 +59,151 @@ MAIN_LOOP:
return string(chars) return string(chars)
} }
// It is not safe for cryptographic usage.
func RandUint16() uint16 { func RandUint16() uint16 {
return uint16(rand.Uint32() & (1<<16 - 1))
return uint16(RandUint32() & (1<<16 - 1))
} }
// It is not safe for cryptographic usage.
func RandUint32() uint32 { func RandUint32() uint32 {
return rand.Uint32()
prng.Lock()
u32 := prng.Uint32()
prng.Unlock()
return u32
} }
// It is not safe for cryptographic usage.
func RandUint64() uint64 { func RandUint64() uint64 {
return uint64(rand.Uint32())<<32 + uint64(rand.Uint32())
return uint64(RandUint32())<<32 + uint64(RandUint32())
} }
// It is not safe for cryptographic usage.
func RandUint() uint { func RandUint() uint {
return uint(rand.Int())
prng.Lock()
i := prng.Int()
prng.Unlock()
return uint(i)
} }
// It is not safe for cryptographic usage.
func RandInt16() int16 { func RandInt16() int16 {
return int16(rand.Uint32() & (1<<16 - 1))
return int16(RandUint32() & (1<<16 - 1))
} }
// It is not safe for cryptographic usage.
func RandInt32() int32 { func RandInt32() int32 {
return int32(rand.Uint32())
return int32(RandUint32())
} }
// It is not safe for cryptographic usage.
func RandInt64() int64 { func RandInt64() int64 {
return int64(rand.Uint32())<<32 + int64(rand.Uint32())
return int64(RandUint64())
} }
// It is not safe for cryptographic usage.
func RandInt() int { func RandInt() int {
return rand.Int()
prng.Lock()
i := prng.Int()
prng.Unlock()
return i
}
// It is not safe for cryptographic usage.
func RandInt31() int32 {
prng.Lock()
i31 := prng.Int31()
prng.Unlock()
return i31
}
// It is not safe for cryptographic usage.
func RandInt63() int64 {
prng.Lock()
i63 := prng.Int63()
prng.Unlock()
return i63
} }
// Distributed pseudo-exponentially to test for various cases // Distributed pseudo-exponentially to test for various cases
// It is not safe for cryptographic usage.
func RandUint16Exp() uint16 { func RandUint16Exp() uint16 {
bits := rand.Uint32() % 16
bits := RandUint32() % 16
if bits == 0 { if bits == 0 {
return 0 return 0
} }
n := uint16(1 << (bits - 1)) n := uint16(1 << (bits - 1))
n += uint16(rand.Int31()) & ((1 << (bits - 1)) - 1)
n += uint16(RandInt31()) & ((1 << (bits - 1)) - 1)
return n return n
} }
// Distributed pseudo-exponentially to test for various cases // Distributed pseudo-exponentially to test for various cases
// It is not safe for cryptographic usage.
func RandUint32Exp() uint32 { func RandUint32Exp() uint32 {
bits := rand.Uint32() % 32
bits := RandUint32() % 32
if bits == 0 { if bits == 0 {
return 0 return 0
} }
n := uint32(1 << (bits - 1)) n := uint32(1 << (bits - 1))
n += uint32(rand.Int31()) & ((1 << (bits - 1)) - 1)
n += uint32(RandInt31()) & ((1 << (bits - 1)) - 1)
return n return n
} }
// Distributed pseudo-exponentially to test for various cases // Distributed pseudo-exponentially to test for various cases
// It is not safe for cryptographic usage.
func RandUint64Exp() uint64 { func RandUint64Exp() uint64 {
bits := rand.Uint32() % 64
bits := RandUint32() % 64
if bits == 0 { if bits == 0 {
return 0 return 0
} }
n := uint64(1 << (bits - 1)) n := uint64(1 << (bits - 1))
n += uint64(rand.Int63()) & ((1 << (bits - 1)) - 1)
n += uint64(RandInt63()) & ((1 << (bits - 1)) - 1)
return n return n
} }
// It is not safe for cryptographic usage.
func RandFloat32() float32 { func RandFloat32() float32 {
return rand.Float32()
prng.Lock()
f32 := prng.Float32()
prng.Unlock()
return f32
} }
// It is not safe for cryptographic usage.
func RandTime() time.Time { func RandTime() time.Time {
return time.Unix(int64(RandUint64Exp()), 0) return time.Unix(int64(RandUint64Exp()), 0)
} }
// RandBytes returns n random bytes from the OS's source of entropy ie. via crypto/rand.
// It is not safe for cryptographic usage.
func RandBytes(n int) []byte { func RandBytes(n int) []byte {
// cRandBytes isn't guaranteed to be fast so instead
// use random bytes generated from the internal PRNG
bs := make([]byte, n) bs := make([]byte, n)
for i := 0; i < n; i++ {
bs[i] = byte(rand.Intn(256))
for i := 0; i < len(bs); i++ {
bs[i] = byte(RandInt() & 0xFF)
} }
return bs return bs
} }
// RandIntn returns, as an int, a non-negative pseudo-random number in [0, n).
// It panics if n <= 0.
// It is not safe for cryptographic usage.
func RandIntn(n int) int {
prng.Lock()
i := prng.Intn(n)
prng.Unlock()
return i
}
// RandPerm returns a pseudo-random permutation of n integers in [0, n).
// It is not safe for cryptographic usage.
func RandPerm(n int) []int {
prng.Lock()
perm := prng.Perm(n)
prng.Unlock()
return perm
}
// NOTE: This relies on the os's random number generator. // NOTE: This relies on the os's random number generator.
// For real security, we should salt that with some seed. // For real security, we should salt that with some seed.
// See github.com/tendermint/go-crypto for a more secure reader. // See github.com/tendermint/go-crypto for a more secure reader.


+ 120
- 0
common/random_test.go View File

@ -0,0 +1,120 @@
package common
import (
"bytes"
"encoding/json"
"fmt"
"io"
mrand "math/rand"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestRandStr(t *testing.T) {
l := 243
s := RandStr(l)
assert.Equal(t, l, len(s))
}
func TestRandBytes(t *testing.T) {
l := 243
b := RandBytes(l)
assert.Equal(t, l, len(b))
}
func TestRandIntn(t *testing.T) {
n := 243
for i := 0; i < 100; i++ {
x := RandIntn(n)
assert.True(t, x < n)
}
}
// It is essential that these tests run and never repeat their outputs
// lest we've been pwned and the behavior of our randomness is controlled.
// See Issues:
// * https://github.com/tendermint/tmlibs/issues/99
// * https://github.com/tendermint/tendermint/issues/973
func TestUniqueRng(t *testing.T) {
buf := new(bytes.Buffer)
outputs := make(map[string][]int)
for i := 0; i < 100; i++ {
testThemAll(buf)
output := buf.String()
buf.Reset()
runs, seen := outputs[output]
if seen {
t.Errorf("Run #%d's output was already seen in previous runs: %v", i, runs)
}
outputs[output] = append(outputs[output], i)
}
}
func testThemAll(out io.Writer) {
// Reset the internal PRNG
reset()
// Set math/rand's Seed so that any direct invocations
// of math/rand will reveal themselves.
mrand.Seed(1)
perm := RandPerm(10)
blob, _ := json.Marshal(perm)
fmt.Fprintf(out, "perm: %s\n", blob)
fmt.Fprintf(out, "randInt: %d\n", RandInt())
fmt.Fprintf(out, "randUint: %d\n", RandUint())
fmt.Fprintf(out, "randIntn: %d\n", RandIntn(97))
fmt.Fprintf(out, "randInt31: %d\n", RandInt31())
fmt.Fprintf(out, "randInt32: %d\n", RandInt32())
fmt.Fprintf(out, "randInt63: %d\n", RandInt63())
fmt.Fprintf(out, "randInt64: %d\n", RandInt64())
fmt.Fprintf(out, "randUint32: %d\n", RandUint32())
fmt.Fprintf(out, "randUint64: %d\n", RandUint64())
fmt.Fprintf(out, "randUint16Exp: %d\n", RandUint16Exp())
fmt.Fprintf(out, "randUint32Exp: %d\n", RandUint32Exp())
fmt.Fprintf(out, "randUint64Exp: %d\n", RandUint64Exp())
}
func TestRngConcurrencySafety(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = RandUint64()
<-time.After(time.Millisecond * time.Duration(RandIntn(100)))
_ = RandPerm(3)
}()
}
wg.Wait()
}
func BenchmarkRandBytes10B(b *testing.B) {
benchmarkRandBytes(b, 10)
}
func BenchmarkRandBytes100B(b *testing.B) {
benchmarkRandBytes(b, 100)
}
func BenchmarkRandBytes1KiB(b *testing.B) {
benchmarkRandBytes(b, 1024)
}
func BenchmarkRandBytes10KiB(b *testing.B) {
benchmarkRandBytes(b, 10*1024)
}
func BenchmarkRandBytes100KiB(b *testing.B) {
benchmarkRandBytes(b, 100*1024)
}
func BenchmarkRandBytes1MiB(b *testing.B) {
benchmarkRandBytes(b, 1024*1024)
}
func benchmarkRandBytes(b *testing.B, n int) {
for i := 0; i < b.N; i++ {
_ = RandBytes(n)
}
b.ReportAllocs()
}

+ 184
- 42
common/repeat_timer.go View File

@ -5,82 +5,224 @@ import (
"time" "time"
) )
// Used by RepeatTimer the first time,
// and every time it's Reset() after Stop().
type TickerMaker func(dur time.Duration) Ticker
// Ticker is a basic ticker interface.
type Ticker interface {
// Never changes, never closes.
Chan() <-chan time.Time
// Stopping a stopped Ticker will panic.
Stop()
}
//----------------------------------------
// defaultTickerMaker
func defaultTickerMaker(dur time.Duration) Ticker {
ticker := time.NewTicker(dur)
return (*defaultTicker)(ticker)
}
type defaultTicker time.Ticker
// Implements Ticker
func (t *defaultTicker) Chan() <-chan time.Time {
return t.C
}
// Implements Ticker
func (t *defaultTicker) Stop() {
((*time.Ticker)(t)).Stop()
}
//----------------------------------------
// LogicalTickerMaker
// Construct a TickerMaker that always uses `source`.
// It's useful for simulating a deterministic clock.
func NewLogicalTickerMaker(source chan time.Time) TickerMaker {
return func(dur time.Duration) Ticker {
return newLogicalTicker(source, dur)
}
}
type logicalTicker struct {
source <-chan time.Time
ch chan time.Time
quit chan struct{}
}
func newLogicalTicker(source <-chan time.Time, interval time.Duration) Ticker {
lt := &logicalTicker{
source: source,
ch: make(chan time.Time),
quit: make(chan struct{}),
}
go lt.fireRoutine(interval)
return lt
}
// We need a goroutine to read times from t.source
// and fire on t.Chan() when `interval` has passed.
func (t *logicalTicker) fireRoutine(interval time.Duration) {
source := t.source
// Init `lasttime`
lasttime := time.Time{}
select {
case lasttime = <-source:
case <-t.quit:
return
}
// Init `lasttime` end
timeleft := interval
for {
select {
case newtime := <-source:
elapsed := newtime.Sub(lasttime)
timeleft -= elapsed
if timeleft <= 0 {
// Block for determinism until the ticker is stopped.
select {
case t.ch <- newtime:
case <-t.quit:
return
}
// Reset timeleft.
// Don't try to "catch up" by sending more.
// "Ticker adjusts the intervals or drops ticks to make up for
// slow receivers" - https://golang.org/pkg/time/#Ticker
timeleft = interval
}
case <-t.quit:
return // done
}
}
}
// Implements Ticker
func (t *logicalTicker) Chan() <-chan time.Time {
return t.ch // immutable
}
// Implements Ticker
func (t *logicalTicker) Stop() {
close(t.quit) // it *should* panic when stopped twice.
}
//---------------------------------------------------------------------
/* /*
RepeatTimer repeatedly sends a struct{}{} to .Ch after each "dur" period.
It's good for keeping connections alive.
A RepeatTimer must be Stop()'d or it will keep a goroutine alive.
RepeatTimer repeatedly sends a struct{}{} to `.Chan()` after each `dur`
period. (It's good for keeping connections alive.)
A RepeatTimer must be stopped, or it will keep a goroutine alive.
*/ */
type RepeatTimer struct { type RepeatTimer struct {
Ch chan time.Time
name string
ch chan time.Time
tm TickerMaker
mtx sync.Mutex mtx sync.Mutex
name string
ticker *time.Ticker
quit chan struct{}
wg *sync.WaitGroup
dur time.Duration dur time.Duration
ticker Ticker
quit chan struct{}
} }
// NewRepeatTimer returns a RepeatTimer with a defaultTicker.
func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer { func NewRepeatTimer(name string, dur time.Duration) *RepeatTimer {
return NewRepeatTimerWithTickerMaker(name, dur, defaultTickerMaker)
}
// NewRepeatTimerWithTicker returns a RepeatTimer with the given ticker
// maker.
func NewRepeatTimerWithTickerMaker(name string, dur time.Duration, tm TickerMaker) *RepeatTimer {
var t = &RepeatTimer{ var t = &RepeatTimer{
Ch: make(chan time.Time),
ticker: time.NewTicker(dur),
quit: make(chan struct{}),
wg: new(sync.WaitGroup),
name: name, name: name,
ch: make(chan time.Time),
tm: tm,
dur: dur, dur: dur,
ticker: nil,
quit: nil,
} }
t.wg.Add(1)
go t.fireRoutine(t.ticker)
t.reset()
return t return t
} }
func (t *RepeatTimer) fireRoutine(ticker *time.Ticker) {
func (t *RepeatTimer) fireRoutine(ch <-chan time.Time, quit <-chan struct{}) {
for { for {
select { select {
case t_ := <-ticker.C:
t.Ch <- t_
case <-t.quit:
// needed so we know when we can reset t.quit
t.wg.Done()
case t_ := <-ch:
t.ch <- t_
case <-quit: // NOTE: `t.quit` races.
return return
} }
} }
} }
func (t *RepeatTimer) Chan() <-chan time.Time {
return t.ch
}
func (t *RepeatTimer) Stop() {
t.mtx.Lock()
defer t.mtx.Unlock()
t.stop()
}
// Wait the duration again before firing. // Wait the duration again before firing.
func (t *RepeatTimer) Reset() { func (t *RepeatTimer) Reset() {
t.Stop()
t.mtx.Lock() // Lock
t.mtx.Lock()
defer t.mtx.Unlock() defer t.mtx.Unlock()
t.ticker = time.NewTicker(t.dur)
t.reset()
}
//----------------------------------------
// Misc.
// CONTRACT: (non-constructor) caller should hold t.mtx.
func (t *RepeatTimer) reset() {
if t.ticker != nil {
t.stop()
}
t.ticker = t.tm(t.dur)
t.quit = make(chan struct{}) t.quit = make(chan struct{})
t.wg.Add(1)
go t.fireRoutine(t.ticker)
go t.fireRoutine(t.ticker.Chan(), t.quit)
} }
// For ease of .Stop()'ing services before .Start()'ing them,
// we ignore .Stop()'s on nil RepeatTimers.
func (t *RepeatTimer) Stop() bool {
if t == nil {
return false
// CONTRACT: caller should hold t.mtx.
func (t *RepeatTimer) stop() {
if t.ticker == nil {
/*
Similar to the case of closing channels twice:
https://groups.google.com/forum/#!topic/golang-nuts/rhxMiNmRAPk
Stopping a RepeatTimer twice implies that you do
not know whether you are done or not.
If you're calling stop on a stopped RepeatTimer,
you probably have race conditions.
*/
panic("Tried to stop a stopped RepeatTimer")
} }
t.mtx.Lock() // Lock
defer t.mtx.Unlock()
t.ticker.Stop()
t.ticker = nil
/*
XXX
From https://golang.org/pkg/time/#Ticker:
"Stop the ticker to release associated resources"
"After Stop, no more ticks will be sent"
So we shouldn't have to do the below.
exists := t.ticker != nil
if exists {
t.ticker.Stop() // does not close the channel
select { select {
case <-t.Ch:
case <-t.ch:
// read off channel if there's anything there // read off channel if there's anything there
default: default:
} }
close(t.quit)
t.wg.Wait() // must wait for quit to close else we race Reset
t.ticker = nil
}
return exists
*/
close(t.quit)
} }

+ 68
- 54
common/repeat_timer_test.go View File

@ -1,78 +1,92 @@
package common package common
import ( import (
"sync"
"testing" "testing"
"time" "time"
// make govet noshadow happy...
asrt "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/assert"
) )
type rCounter struct {
input chan time.Time
mtx sync.Mutex
count int
func TestDefaultTicker(t *testing.T) {
ticker := defaultTickerMaker(time.Millisecond * 10)
<-ticker.Chan()
ticker.Stop()
} }
func (c *rCounter) Increment() {
c.mtx.Lock()
c.count++
c.mtx.Unlock()
}
func TestRepeat(t *testing.T) {
func (c *rCounter) Count() int {
c.mtx.Lock()
val := c.count
c.mtx.Unlock()
return val
}
ch := make(chan time.Time, 100)
lt := time.Time{} // zero time is year 1
// Read should run in a go-routine and
// updates count by one every time a packet comes in
func (c *rCounter) Read() {
for range c.input {
c.Increment()
// tick fires `cnt` times for each second.
tick := func(cnt int) {
for i := 0; i < cnt; i++ {
lt = lt.Add(time.Second)
ch <- lt
}
} }
}
func TestRepeat(test *testing.T) {
assert := asrt.New(test)
dur := time.Duration(50) * time.Millisecond
short := time.Duration(20) * time.Millisecond
// delay waits for cnt durations, an a little extra
delay := func(cnt int) time.Duration {
return time.Duration(cnt)*dur + time.Millisecond
// tock consumes Ticker.Chan() events `cnt` times.
tock := func(t *testing.T, rt *RepeatTimer, cnt int) {
for i := 0; i < cnt; i++ {
timeout := time.After(time.Second * 10)
select {
case <-rt.Chan():
case <-timeout:
panic("expected RepeatTimer to fire")
}
}
done := true
select {
case <-rt.Chan():
done = false
default:
}
assert.True(t, done)
} }
t := NewRepeatTimer("bar", dur)
// start at 0
c := &rCounter{input: t.Ch}
go c.Read()
assert.Equal(0, c.Count())
tm := NewLogicalTickerMaker(ch)
dur := time.Duration(10 * time.Millisecond) // less than a second
rt := NewRepeatTimerWithTickerMaker("bar", dur, tm)
// Start at 0.
tock(t, rt, 0)
tick(1) // init time
// wait for 4 periods
time.Sleep(delay(4))
assert.Equal(4, c.Count())
tock(t, rt, 0)
tick(1) // wait 1 periods
tock(t, rt, 1)
tick(2) // wait 2 periods
tock(t, rt, 2)
tick(3) // wait 3 periods
tock(t, rt, 3)
tick(4) // wait 4 periods
tock(t, rt, 4)
// keep reseting leads to no firing
// Multiple resets leads to no firing.
for i := 0; i < 20; i++ { for i := 0; i < 20; i++ {
time.Sleep(short)
t.Reset()
time.Sleep(time.Millisecond)
rt.Reset()
} }
assert.Equal(4, c.Count())
// after this, it still works normal
time.Sleep(delay(2))
assert.Equal(6, c.Count())
// After this, it works as new.
tock(t, rt, 0)
tick(1) // init time
tock(t, rt, 0)
tick(1) // wait 1 periods
tock(t, rt, 1)
tick(2) // wait 2 periods
tock(t, rt, 2)
tick(3) // wait 3 periods
tock(t, rt, 3)
tick(4) // wait 4 periods
tock(t, rt, 4)
// after a stop, nothing more is sent
stopped := t.Stop()
assert.True(stopped)
time.Sleep(delay(7))
assert.Equal(6, c.Count())
// After a stop, nothing more is sent.
rt.Stop()
tock(t, rt, 0)
// close channel to stop counter
close(t.Ch)
// Another stop panics.
assert.Panics(t, func() { rt.Stop() })
} }

+ 56
- 13
pubsub/pubsub.go View File

@ -13,6 +13,8 @@ package pubsub
import ( import (
"context" "context"
"errors"
"sync"
cmn "github.com/tendermint/tmlibs/common" cmn "github.com/tendermint/tmlibs/common"
) )
@ -38,6 +40,7 @@ type cmd struct {
// Query defines an interface for a query to be used for subscribing. // Query defines an interface for a query to be used for subscribing.
type Query interface { type Query interface {
Matches(tags map[string]interface{}) bool Matches(tags map[string]interface{}) bool
String() string
} }
// Server allows clients to subscribe/unsubscribe for messages, publishing // Server allows clients to subscribe/unsubscribe for messages, publishing
@ -47,6 +50,9 @@ type Server struct {
cmds chan cmd cmds chan cmd
cmdsCap int cmdsCap int
mtx sync.RWMutex
subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{}
} }
// Option sets a parameter for the server. // Option sets a parameter for the server.
@ -56,7 +62,9 @@ type Option func(*Server)
// for a detailed description of how to configure buffering. If no options are // for a detailed description of how to configure buffering. If no options are
// provided, the resulting server's queue is unbuffered. // provided, the resulting server's queue is unbuffered.
func NewServer(options ...Option) *Server { func NewServer(options ...Option) *Server {
s := &Server{}
s := &Server{
subscriptions: make(map[string]map[string]struct{}),
}
s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) s.BaseService = *cmn.NewBaseService(nil, "PubSub", s)
for _, option := range options { for _, option := range options {
@ -82,17 +90,33 @@ func BufferCapacity(cap int) Option {
} }
// BufferCapacity returns capacity of the internal server's queue. // BufferCapacity returns capacity of the internal server's queue.
func (s Server) BufferCapacity() int {
func (s *Server) BufferCapacity() int {
return s.cmdsCap return s.cmdsCap
} }
// Subscribe creates a subscription for the given client. It accepts a channel // Subscribe creates a subscription for the given client. It accepts a channel
// on which messages matching the given query can be received. If the
// subscription already exists, the old channel will be closed. An error will
// be returned to the caller if the context is canceled.
// on which messages matching the given query can be received. An error will be
// returned to the caller if the context is canceled or if subscription already
// exist for pair clientID and query.
func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error { func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, out chan<- interface{}) error {
s.mtx.RLock()
clientSubscriptions, ok := s.subscriptions[clientID]
if ok {
_, ok = clientSubscriptions[query.String()]
}
s.mtx.RUnlock()
if ok {
return errors.New("already subscribed")
}
select { select {
case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}:
s.mtx.Lock()
if _, ok = s.subscriptions[clientID]; !ok {
s.subscriptions[clientID] = make(map[string]struct{})
}
s.subscriptions[clientID][query.String()] = struct{}{}
s.mtx.Unlock()
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -100,10 +124,24 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou
} }
// Unsubscribe removes the subscription on the given query. An error will be // Unsubscribe removes the subscription on the given query. An error will be
// returned to the caller if the context is canceled.
// returned to the caller if the context is canceled or if subscription does
// not exist.
func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error {
s.mtx.RLock()
clientSubscriptions, ok := s.subscriptions[clientID]
if ok {
_, ok = clientSubscriptions[query.String()]
}
s.mtx.RUnlock()
if !ok {
return errors.New("subscription not found")
}
select { select {
case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}:
s.mtx.Lock()
delete(clientSubscriptions, query.String())
s.mtx.Unlock()
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -111,10 +149,20 @@ func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query)
} }
// UnsubscribeAll removes all client subscriptions. An error will be returned // UnsubscribeAll removes all client subscriptions. An error will be returned
// to the caller if the context is canceled.
// to the caller if the context is canceled or if subscription does not exist.
func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error {
s.mtx.RLock()
_, ok := s.subscriptions[clientID]
s.mtx.RUnlock()
if !ok {
return errors.New("subscription not found")
}
select { select {
case s.cmds <- cmd{op: unsub, clientID: clientID}: case s.cmds <- cmd{op: unsub, clientID: clientID}:
s.mtx.Lock()
delete(s.subscriptions, clientID)
s.mtx.Unlock()
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -186,13 +234,8 @@ loop:
func (state *state) add(clientID string, q Query, ch chan<- interface{}) { func (state *state) add(clientID string, q Query, ch chan<- interface{}) {
// add query if needed // add query if needed
if clientToChannelMap, ok := state.queries[q]; !ok {
if _, ok := state.queries[q]; !ok {
state.queries[q] = make(map[string]chan<- interface{}) state.queries[q] = make(map[string]chan<- interface{})
} else {
// check if already subscribed
if oldCh, ok := clientToChannelMap[clientID]; ok {
close(oldCh)
}
} }
// create subscription // create subscription


+ 25
- 7
pubsub/pubsub_test.go View File

@ -86,14 +86,11 @@ func TestClientSubscribesTwice(t *testing.T) {
ch2 := make(chan interface{}, 1) ch2 := make(chan interface{}, 1)
err = s.Subscribe(ctx, clientID, q, ch2) err = s.Subscribe(ctx, clientID, q, ch2)
require.NoError(t, err)
_, ok := <-ch1
assert.False(t, ok)
require.Error(t, err)
err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"}) err = s.PublishWithTags(ctx, "Spider-Man", map[string]interface{}{"tm.events.type": "NewBlock"})
require.NoError(t, err) require.NoError(t, err)
assertReceive(t, "Spider-Man", ch2)
assertReceive(t, "Spider-Man", ch1)
} }
func TestUnsubscribe(t *testing.T) { func TestUnsubscribe(t *testing.T) {
@ -117,6 +114,27 @@ func TestUnsubscribe(t *testing.T) {
assert.False(t, ok) assert.False(t, ok)
} }
func TestResubscribe(t *testing.T) {
s := pubsub.NewServer()
s.SetLogger(log.TestingLogger())
s.Start()
defer s.Stop()
ctx := context.Background()
ch := make(chan interface{})
err := s.Subscribe(ctx, clientID, query.Empty{}, ch)
require.NoError(t, err)
err = s.Unsubscribe(ctx, clientID, query.Empty{})
require.NoError(t, err)
ch = make(chan interface{})
err = s.Subscribe(ctx, clientID, query.Empty{}, ch)
require.NoError(t, err)
err = s.Publish(ctx, "Cable")
require.NoError(t, err)
assertReceive(t, "Cable", ch)
}
func TestUnsubscribeAll(t *testing.T) { func TestUnsubscribeAll(t *testing.T) {
s := pubsub.NewServer() s := pubsub.NewServer()
s.SetLogger(log.TestingLogger()) s.SetLogger(log.TestingLogger())
@ -125,9 +143,9 @@ func TestUnsubscribeAll(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1) ch1, ch2 := make(chan interface{}, 1), make(chan interface{}, 1)
err := s.Subscribe(ctx, clientID, query.Empty{}, ch1)
err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch1)
require.NoError(t, err) require.NoError(t, err)
err = s.Subscribe(ctx, clientID, query.Empty{}, ch2)
err = s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlockHeader'"), ch2)
require.NoError(t, err) require.NoError(t, err)
err = s.UnsubscribeAll(ctx, clientID) err = s.UnsubscribeAll(ctx, clientID)


+ 6
- 6
test.sh View File

@ -2,14 +2,14 @@
set -e set -e
# run the linter # run the linter
make metalinter_test
# make metalinter_test
# run the unit tests with coverage # run the unit tests with coverage
echo "" > coverage.txt echo "" > coverage.txt
for d in $(go list ./... | grep -v vendor); do for d in $(go list ./... | grep -v vendor); do
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
go test -race -coverprofile=profile.out -covermode=atomic "$d"
if [ -f profile.out ]; then
cat profile.out >> coverage.txt
rm profile.out
fi
done done

+ 1
- 1
version/version.go View File

@ -1,3 +1,3 @@
package version package version
const Version = "0.5.0"
const Version = "0.6.0"

Loading…
Cancel
Save