diff --git a/common/throttle_timer.go b/common/throttle_timer.go index 38ef4e9a3..e260e01bd 100644 --- a/common/throttle_timer.go +++ b/common/throttle_timer.go @@ -1,7 +1,7 @@ package common import ( - "sync" + "fmt" "time" ) @@ -12,54 +12,88 @@ If a long continuous burst of .Set() calls happens, ThrottleTimer fires at most once every "dur". */ type ThrottleTimer struct { - Name string - Ch chan struct{} - quit chan struct{} - dur time.Duration + Name string + Ch chan struct{} + input chan command + dur time.Duration - mtx sync.Mutex timer *time.Timer isSet bool } +type command int32 + +const ( + Set command = iota + Unset + Quit +) + func NewThrottleTimer(name string, dur time.Duration) *ThrottleTimer { - var ch = make(chan struct{}) - var quit = make(chan struct{}) - var t = &ThrottleTimer{Name: name, Ch: ch, dur: dur, quit: quit} - t.mtx.Lock() - t.timer = time.AfterFunc(dur, t.fireRoutine) - t.mtx.Unlock() + var t = &ThrottleTimer{ + Name: name, + Ch: make(chan struct{}, 1), + dur: dur, + input: make(chan command), + timer: time.NewTimer(dur), + } t.timer.Stop() + go t.run() return t } -func (t *ThrottleTimer) fireRoutine() { - t.mtx.Lock() - defer t.mtx.Unlock() - select { - case t.Ch <- struct{}{}: - t.isSet = false - case <-t.quit: - // do nothing +func (t *ThrottleTimer) run() { + for { + select { + case cmd := <-t.input: + // stop goroutine if the input says so + if t.processInput(cmd) { + // TODO: do we want to close the channels??? + // close(t.Ch) + // close(t.input) + return + } + case <-t.timer.C: + t.isSet = false + t.Ch <- struct{}{} + } + } +} + +// all modifications of the internal state of ThrottleTimer +// happen in this method. It is only called from the run goroutine +// so we avoid any race conditions +func (t *ThrottleTimer) processInput(cmd command) (shutdown bool) { + fmt.Printf("processInput: %d\n", cmd) + switch cmd { + case Set: + if !t.isSet { + t.isSet = true + t.timer.Reset(t.dur) + } + case Quit: + shutdown = true + fallthrough + case Unset: + if t.isSet { + t.isSet = false + if !t.timer.Stop() { + <-t.timer.C + } + } default: - t.timer.Reset(t.dur) + panic("unknown command!") } + // return true + return shutdown } func (t *ThrottleTimer) Set() { - t.mtx.Lock() - defer t.mtx.Unlock() - if !t.isSet { - t.isSet = true - t.timer.Reset(t.dur) - } + t.input <- Set } func (t *ThrottleTimer) Unset() { - t.mtx.Lock() - defer t.mtx.Unlock() - t.isSet = false - t.timer.Stop() + t.input <- Unset } // For ease of .Stop()'ing services before .Start()'ing them, @@ -68,8 +102,6 @@ func (t *ThrottleTimer) Stop() bool { if t == nil { return false } - close(t.quit) - t.mtx.Lock() - defer t.mtx.Unlock() - return t.timer.Stop() + t.input <- Quit + return true } diff --git a/common/throttle_timer_test.go b/common/throttle_timer_test.go index 00f5abdec..014f9dcdc 100644 --- a/common/throttle_timer_test.go +++ b/common/throttle_timer_test.go @@ -41,6 +41,7 @@ func TestThrottle(test *testing.T) { ms := 50 delay := time.Duration(ms) * time.Millisecond + shortwait := time.Duration(ms/2) * time.Millisecond longwait := time.Duration(2) * delay t := NewThrottleTimer("foo", delay) @@ -65,6 +66,21 @@ func TestThrottle(test *testing.T) { time.Sleep(longwait) assert.Equal(2, c.Count()) + // keep cancelling before it is ready + for i := 0; i < 10; i++ { + t.Set() + time.Sleep(shortwait) + t.Unset() + } + time.Sleep(longwait) + assert.Equal(2, c.Count()) + + // a few unsets do nothing... + for i := 0; i < 5; i++ { + t.Unset() + } + assert.Equal(2, c.Count()) + // send 12, over 2 delay sections, adds 3 short := time.Duration(ms/5) * time.Millisecond for i := 0; i < 13; i++ { @@ -74,5 +90,6 @@ func TestThrottle(test *testing.T) { time.Sleep(longwait) assert.Equal(5, c.Count()) - close(t.Ch) + stopped := t.Stop() + assert.True(stopped) }