diff --git a/common/repeat_timer_test.go b/common/repeat_timer_test.go index 9f03f41df..87f34b950 100644 --- a/common/repeat_timer_test.go +++ b/common/repeat_timer_test.go @@ -43,7 +43,7 @@ func TestRepeat(test *testing.T) { short := time.Duration(20) * time.Millisecond // delay waits for cnt durations, an a little extra delay := func(cnt int) time.Duration { - return time.Duration(cnt)*dur + time.Millisecond + return time.Duration(cnt)*dur + time.Duration(5)*time.Millisecond } t := NewRepeatTimer("bar", dur) diff --git a/common/throttle_timer.go b/common/throttle_timer.go index 38ef4e9a3..ab2ad2e62 100644 --- a/common/throttle_timer.go +++ b/common/throttle_timer.go @@ -1,7 +1,6 @@ package common import ( - "sync" "time" ) @@ -12,64 +11,117 @@ If a long continuous burst of .Set() calls happens, ThrottleTimer fires at most once every "dur". */ type ThrottleTimer struct { - Name string - Ch chan struct{} - quit chan struct{} - dur time.Duration + Name string + Ch <-chan struct{} + input chan command + output chan<- struct{} + dur time.Duration - mtx sync.Mutex timer *time.Timer isSet bool } +type command int32 + +const ( + Set command = iota + Unset + Quit +) + +// NewThrottleTimer creates a new ThrottleTimer. func NewThrottleTimer(name string, dur time.Duration) *ThrottleTimer { - var ch = make(chan struct{}) - var quit = make(chan struct{}) - var t = &ThrottleTimer{Name: name, Ch: ch, dur: dur, quit: quit} - t.mtx.Lock() - t.timer = time.AfterFunc(dur, t.fireRoutine) - t.mtx.Unlock() + c := make(chan struct{}) + var t = &ThrottleTimer{ + Name: name, + Ch: c, + dur: dur, + input: make(chan command), + output: c, + timer: time.NewTimer(dur), + } t.timer.Stop() + go t.run() return t } -func (t *ThrottleTimer) fireRoutine() { - t.mtx.Lock() - defer t.mtx.Unlock() +func (t *ThrottleTimer) run() { + for { + select { + case cmd := <-t.input: + // stop goroutine if the input says so + // don't close channels, as closed channels mess up select reads + if t.processInput(cmd) { + return + } + case <-t.timer.C: + t.trySend() + } + } +} + +// trySend performs non-blocking send on t.Ch +func (t *ThrottleTimer) trySend() { select { - case t.Ch <- struct{}{}: + case t.output <- struct{}{}: t.isSet = false - case <-t.quit: - // do nothing default: + // if we just want to drop, replace this with t.isSet = false t.timer.Reset(t.dur) } } -func (t *ThrottleTimer) Set() { - t.mtx.Lock() - defer t.mtx.Unlock() - if !t.isSet { - t.isSet = true - t.timer.Reset(t.dur) +// all modifications of the internal state of ThrottleTimer +// happen in this method. It is only called from the run goroutine +// so we avoid any race conditions +func (t *ThrottleTimer) processInput(cmd command) (shutdown bool) { + switch cmd { + case Set: + if !t.isSet { + t.isSet = true + t.timer.Reset(t.dur) + } + case Quit: + shutdown = true + fallthrough + case Unset: + if t.isSet { + t.isSet = false + t.timer.Stop() + } + default: + panic("unknown command!") } + return shutdown +} + +func (t *ThrottleTimer) Set() { + t.input <- Set } func (t *ThrottleTimer) Unset() { - t.mtx.Lock() - defer t.mtx.Unlock() - t.isSet = false - t.timer.Stop() + t.input <- Unset } -// For ease of .Stop()'ing services before .Start()'ing them, -// we ignore .Stop()'s on nil ThrottleTimers +// Stop prevents the ThrottleTimer from firing. It always returns true. Stop does not +// close the channel, to prevent a read from the channel succeeding +// incorrectly. +// +// To prevent a timer created with NewThrottleTimer from firing after a call to +// Stop, check the return value and drain the channel. +// +// For example, assuming the program has not received from t.C already: +// +// if !t.Stop() { +// <-t.C +// } +// +// For ease of stopping services before starting them, we ignore Stop on nil +// ThrottleTimers. func (t *ThrottleTimer) Stop() bool { if t == nil { return false } - close(t.quit) - t.mtx.Lock() - defer t.mtx.Unlock() - return t.timer.Stop() + t.input <- Quit + return true } diff --git a/common/throttle_timer_test.go b/common/throttle_timer_test.go index 00f5abdec..a1b6606f5 100644 --- a/common/throttle_timer_test.go +++ b/common/throttle_timer_test.go @@ -10,7 +10,7 @@ import ( ) type thCounter struct { - input chan struct{} + input <-chan struct{} mtx sync.Mutex count int } @@ -31,6 +31,9 @@ func (c *thCounter) Count() int { // Read should run in a go-routine and // updates count by one every time a packet comes in func (c *thCounter) Read() { + // note, since this channel never closes, this will never end + // if thCounter was used in anything beyond trivial test cases. + // it would have to be smarter. for range c.input { c.Increment() } @@ -41,6 +44,7 @@ func TestThrottle(test *testing.T) { ms := 50 delay := time.Duration(ms) * time.Millisecond + shortwait := time.Duration(ms/2) * time.Millisecond longwait := time.Duration(2) * delay t := NewThrottleTimer("foo", delay) @@ -65,6 +69,21 @@ func TestThrottle(test *testing.T) { time.Sleep(longwait) assert.Equal(2, c.Count()) + // keep cancelling before it is ready + for i := 0; i < 10; i++ { + t.Set() + time.Sleep(shortwait) + t.Unset() + } + time.Sleep(longwait) + assert.Equal(2, c.Count()) + + // a few unsets do nothing... + for i := 0; i < 5; i++ { + t.Unset() + } + assert.Equal(2, c.Count()) + // send 12, over 2 delay sections, adds 3 short := time.Duration(ms/5) * time.Millisecond for i := 0; i < 13; i++ { @@ -74,5 +93,6 @@ func TestThrottle(test *testing.T) { time.Sleep(longwait) assert.Equal(5, c.Count()) - close(t.Ch) + stopped := t.Stop() + assert.True(stopped) }