diff --git a/CHANGELOG.md b/CHANGELOG.md index 69026e113..7a55fb554 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,37 @@ # Changelog +## 0.8.1 (April 5th, 2018) + +FEATURES: + + - [common] Error.Error() includes cause + - [common] IsEmpty() for 0 length + +## 0.8.0 (April 4th, 2018) + +BREAKING: + + - [merkle] `PutVarint->PutUvarint` in encodeByteSlice + - [db] batch.WriteSync() + - [common] Refactored and fixed `Parallel` function + - [common] Refactored `Rand` functionality + - [common] Remove unused `Right/LeftPadString` functions + - [common] Remove StackError, introduce Error interface (to replace use of pkg/errors) + +FEATURES: + + - [db] NewPrefixDB for a DB with all keys prefixed + - [db] NewDebugDB prints everything during operation + - [common] SplitAndTrim func + - [common] rand.Float64(), rand.Int63n(n), rand.Int31n(n) and global equivalents + - [common] HexBytes Format() + +BUG FIXES: + + - [pubsub] Fix unsubscribing + - [cli] Return config errors + - [common] Fix WriteFileAtomic Windows bug + ## 0.7.1 (March 22, 2018) IMPROVEMENTS: diff --git a/autofile/autofile.go b/autofile/autofile.go index 05fb0d677..790be5224 100644 --- a/autofile/autofile.go +++ b/autofile/autofile.go @@ -5,7 +5,7 @@ import ( "sync" "time" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) /* AutoFile usage @@ -44,7 +44,7 @@ type AutoFile struct { func OpenAutoFile(path string) (af *AutoFile, err error) { af = &AutoFile{ - ID: RandStr(12) + ":" + path, + ID: cmn.RandStr(12) + ":" + path, Path: path, ticker: time.NewTicker(autoFileOpenDuration), } @@ -129,9 +129,8 @@ func (af *AutoFile) Size() (int64, error) { if err != nil { if err == os.ErrNotExist { return 0, nil - } else { - return -1, err } + return -1, err } } stat, err := af.file.Stat() diff --git a/autofile/group.go b/autofile/group.go index f2d0f2bae..652c33310 100644 --- a/autofile/group.go +++ b/autofile/group.go @@ -15,7 +15,7 @@ import ( "sync" "time" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) const ( @@ -54,7 +54,7 @@ The Group can also be used to binary-search for some line, assuming that marker lines are written occasionally. */ type Group struct { - BaseService + cmn.BaseService ID string Head *AutoFile // The head AutoFile to write to @@ -90,7 +90,7 @@ func OpenGroup(headPath string) (g *Group, err error) { minIndex: 0, maxIndex: 0, } - g.BaseService = *NewBaseService(nil, "Group", g) + g.BaseService = *cmn.NewBaseService(nil, "Group", g) gInfo := g.readGroupInfo() g.minIndex = gInfo.MinIndex @@ -267,7 +267,7 @@ func (g *Group) RotateFile() { panic(err) } - g.maxIndex += 1 + g.maxIndex++ } // NewReader returns a new group reader. @@ -277,9 +277,8 @@ func (g *Group) NewReader(index int) (*GroupReader, error) { err := r.SetIndex(index) if err != nil { return nil, err - } else { - return r, nil } + return r, nil } // Returns -1 if line comes after, 0 if found, 1 if line comes before. @@ -311,9 +310,8 @@ func (g *Group) Search(prefix string, cmp SearchFunc) (*GroupReader, bool, error if err != nil { r.Close() return nil, false, err - } else { - return r, match, err } + return r, match, err } // Read starting roughly at the middle file, @@ -349,9 +347,8 @@ func (g *Group) Search(prefix string, cmp SearchFunc) (*GroupReader, bool, error if err != nil { r.Close() return nil, false, err - } else { - return r, true, err } + return r, true, err } else { // We passed it maxIndex = curIndex - 1 @@ -429,9 +426,8 @@ GROUP_LOOP: if err == io.EOF { if found { return match, found, nil - } else { - continue GROUP_LOOP } + continue GROUP_LOOP } else if err != nil { return "", false, err } @@ -442,9 +438,8 @@ GROUP_LOOP: if r.CurIndex() > i { if found { return match, found, nil - } else { - continue GROUP_LOOP } + continue GROUP_LOOP } } } @@ -520,7 +515,7 @@ func (g *Group) readGroupInfo() GroupInfo { minIndex, maxIndex = 0, 0 } else { // Otherwise, the head file is 1 greater - maxIndex += 1 + maxIndex++ } return GroupInfo{minIndex, maxIndex, totalSize, headSize} } @@ -528,9 +523,8 @@ func (g *Group) readGroupInfo() GroupInfo { func filePathForIndex(headPath string, index int, maxIndex int) string { if index == maxIndex { return headPath - } else { - return fmt.Sprintf("%v.%03d", headPath, index) } + return fmt.Sprintf("%v.%03d", headPath, index) } //-------------------------------------------------------------------------------- @@ -567,9 +561,8 @@ func (gr *GroupReader) Close() error { gr.curFile = nil gr.curLine = nil return err - } else { - return nil } + return nil } // Read implements io.Reader, reading bytes from the current Reader @@ -598,10 +591,10 @@ func (gr *GroupReader) Read(p []byte) (n int, err error) { if err == io.EOF { if n >= lenP { return n, nil - } else { // Open the next file - if err1 := gr.openFile(gr.curIndex + 1); err1 != nil { - return n, err1 - } + } + // Open the next file + if err1 := gr.openFile(gr.curIndex + 1); err1 != nil { + return n, err1 } } else if err != nil { return n, err @@ -643,10 +636,9 @@ func (gr *GroupReader) ReadLine() (string, error) { } if len(bytesRead) > 0 && bytesRead[len(bytesRead)-1] == byte('\n') { return linePrefix + string(bytesRead[:len(bytesRead)-1]), nil - } else { - linePrefix += string(bytesRead) - continue } + linePrefix += string(bytesRead) + continue } else if err != nil { return "", err } @@ -726,11 +718,11 @@ func (gr *GroupReader) SetIndex(index int) error { func MakeSimpleSearchFunc(prefix string, target int) SearchFunc { return func(line string) (int, error) { if !strings.HasPrefix(line, prefix) { - return -1, errors.New(Fmt("Marker line did not have prefix: %v", prefix)) + return -1, errors.New(cmn.Fmt("Marker line did not have prefix: %v", prefix)) } i, err := strconv.Atoi(line[len(prefix):]) if err != nil { - return -1, errors.New(Fmt("Failed to parse marker line: %v", err.Error())) + return -1, errors.New(cmn.Fmt("Failed to parse marker line: %v", err.Error())) } if target < i { return 1, nil diff --git a/autofile/group_test.go b/autofile/group_test.go index c4f68f057..1a1111961 100644 --- a/autofile/group_test.go +++ b/autofile/group_test.go @@ -175,7 +175,7 @@ func TestSearch(t *testing.T) { if !strings.HasPrefix(line, fmt.Sprintf("INFO %v ", cur)) { t.Fatalf("Unexpected INFO #. Expected %v got:\n%v", cur, line) } - cur += 1 + cur++ } gr.Close() } diff --git a/cli/setup.go b/cli/setup.go index dc34abdf9..06cf1cd1f 100644 --- a/cli/setup.go +++ b/cli/setup.go @@ -139,9 +139,8 @@ func bindFlagsLoadViper(cmd *cobra.Command, args []string) error { // stderr, so if we redirect output to json file, this doesn't appear // fmt.Fprintln(os.Stderr, "Using config file:", viper.ConfigFileUsed()) } else if _, ok := err.(viper.ConfigFileNotFoundError); !ok { - // we ignore not found error, only parse error - // stderr, so if we redirect output to json file, this doesn't appear - fmt.Fprintf(os.Stderr, "%#v", err) + // ignore not found error, return other errors + return err } return nil } diff --git a/clist/clist.go b/clist/clist.go index 28d771a28..ccb1f5777 100644 --- a/clist/clist.go +++ b/clist/clist.go @@ -316,7 +316,7 @@ func (l *CList) PushBack(v interface{}) *CElement { l.wg.Done() close(l.waitCh) } - l.len += 1 + l.len++ // Modify the tail if l.tail == nil { @@ -357,7 +357,7 @@ func (l *CList) Remove(e *CElement) interface{} { } // Update l.len - l.len -= 1 + l.len-- // Connect next/prev and set head/tail if prev == nil { diff --git a/clist/clist_test.go b/clist/clist_test.go index 31f821653..6171f1a39 100644 --- a/clist/clist_test.go +++ b/clist/clist_test.go @@ -122,7 +122,7 @@ func _TestGCRandom(t *testing.T) { v.Int = i l.PushBack(v) runtime.SetFinalizer(v, func(v *value) { - gcCount += 1 + gcCount++ }) } @@ -177,10 +177,10 @@ func TestScanRightDeleteRandom(t *testing.T) { } if el == nil { el = l.FrontWait() - restartCounter += 1 + restartCounter++ } el = el.Next() - counter += 1 + counter++ } fmt.Printf("Scanner %v restartCounter: %v counter: %v\n", scannerID, restartCounter, counter) }(i) diff --git a/common/async.go b/common/async.go index 1d302c344..49714d95e 100644 --- a/common/async.go +++ b/common/async.go @@ -1,15 +1,148 @@ package common -import "sync" - -func Parallel(tasks ...func()) { - var wg sync.WaitGroup - wg.Add(len(tasks)) - for _, task := range tasks { - go func(task func()) { - task() - wg.Done() - }(task) - } - wg.Wait() +import ( + "sync/atomic" +) + +//---------------------------------------- +// Task + +// val: the value returned after task execution. +// err: the error returned during task completion. +// abort: tells Parallel to return, whether or not all tasks have completed. +type Task func(i int) (val interface{}, err error, abort bool) + +type TaskResult struct { + Value interface{} + Error error +} + +type TaskResultCh <-chan TaskResult + +type taskResultOK struct { + TaskResult + OK bool +} + +type TaskResultSet struct { + chz []TaskResultCh + results []taskResultOK +} + +func newTaskResultSet(chz []TaskResultCh) *TaskResultSet { + return &TaskResultSet{ + chz: chz, + results: nil, + } +} + +func (trs *TaskResultSet) Channels() []TaskResultCh { + return trs.chz +} + +func (trs *TaskResultSet) LatestResult(index int) (TaskResult, bool) { + if len(trs.results) <= index { + return TaskResult{}, false + } + resultOK := trs.results[index] + return resultOK.TaskResult, resultOK.OK +} + +// NOTE: Not concurrency safe. +func (trs *TaskResultSet) Reap() *TaskResultSet { + if trs.results == nil { + trs.results = make([]taskResultOK, len(trs.chz)) + } + for i := 0; i < len(trs.results); i++ { + var trch = trs.chz[i] + select { + case result := <-trch: + // Overwrite result. + trs.results[i] = taskResultOK{ + TaskResult: result, + OK: true, + } + default: + // Do nothing. + } + } + return trs +} + +// Returns the firstmost (by task index) error as +// discovered by all previous Reap() calls. +func (trs *TaskResultSet) FirstValue() interface{} { + for _, result := range trs.results { + if result.Value != nil { + return result.Value + } + } + return nil +} + +// Returns the firstmost (by task index) error as +// discovered by all previous Reap() calls. +func (trs *TaskResultSet) FirstError() error { + for _, result := range trs.results { + if result.Error != nil { + return result.Error + } + } + return nil +} + +//---------------------------------------- +// Parallel + +// Run tasks in parallel, with ability to abort early. +// Returns ok=false iff any of the tasks returned abort=true. +// NOTE: Do not implement quit features here. Instead, provide convenient +// concurrent quit-like primitives, passed implicitly via Task closures. (e.g. +// it's not Parallel's concern how you quit/abort your tasks). +func Parallel(tasks ...Task) (trs *TaskResultSet, ok bool) { + var taskResultChz = make([]TaskResultCh, len(tasks)) // To return. + var taskDoneCh = make(chan bool, len(tasks)) // A "wait group" channel, early abort if any true received. + var numPanics = new(int32) // Keep track of panics to set ok=false later. + ok = true // We will set it to false iff any tasks panic'd or returned abort. + + // Start all tasks in parallel in separate goroutines. + // When the task is complete, it will appear in the + // respective taskResultCh (associated by task index). + for i, task := range tasks { + var taskResultCh = make(chan TaskResult, 1) // Capacity for 1 result. + taskResultChz[i] = taskResultCh + go func(i int, task Task, taskResultCh chan TaskResult) { + // Recovery + defer func() { + if pnk := recover(); pnk != nil { + atomic.AddInt32(numPanics, 1) + taskResultCh <- TaskResult{nil, ErrorWrap(pnk, "Panic in task")} + taskDoneCh <- false + } + }() + // Run the task. + var val, err, abort = task(i) + // Send val/err to taskResultCh. + // NOTE: Below this line, nothing must panic/ + taskResultCh <- TaskResult{val, err} + // Decrement waitgroup. + taskDoneCh <- abort + }(i, task, taskResultCh) + } + + // Wait until all tasks are done, or until abort. + // DONE_LOOP: + for i := 0; i < len(tasks); i++ { + abort := <-taskDoneCh + if abort { + ok = false + break + } + } + + // Ok is also false if there were any panics. + // We must do this check here (after DONE_LOOP). + ok = ok && (atomic.LoadInt32(numPanics) == 0) + + return newTaskResultSet(taskResultChz).Reap(), ok } diff --git a/common/async_test.go b/common/async_test.go new file mode 100644 index 000000000..9f060ca2d --- /dev/null +++ b/common/async_test.go @@ -0,0 +1,152 @@ +package common + +import ( + "errors" + "fmt" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParallel(t *testing.T) { + + // Create tasks. + var counter = new(int32) + var tasks = make([]Task, 100*1000) + for i := 0; i < len(tasks); i++ { + tasks[i] = func(i int) (res interface{}, err error, abort bool) { + atomic.AddInt32(counter, 1) + return -1 * i, nil, false + } + } + + // Run in parallel. + var trs, ok = Parallel(tasks...) + assert.True(t, ok) + + // Verify. + assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") + var failedTasks int + for i := 0; i < len(tasks); i++ { + taskResult, ok := trs.LatestResult(i) + if !ok { + assert.Fail(t, "Task #%v did not complete.", i) + failedTasks++ + } else if taskResult.Error != nil { + assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) + failedTasks++ + } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { + assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) + failedTasks++ + } else { + // Good! + } + } + assert.Equal(t, failedTasks, 0, "No task should have failed") + assert.Nil(t, trs.FirstError(), "There should be no errors") + assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") +} + +func TestParallelAbort(t *testing.T) { + + var flow1 = make(chan struct{}, 1) + var flow2 = make(chan struct{}, 1) + var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking. + var flow4 = make(chan struct{}, 1) + + // Create tasks. + var tasks = []Task{ + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 0) + flow1 <- struct{}{} + return 0, nil, false + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 1) + flow2 <- <-flow1 + return 1, errors.New("some error"), false + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 2) + flow3 <- <-flow2 + return 2, nil, true + }, + func(i int) (res interface{}, err error, abort bool) { + assert.Equal(t, i, 3) + <-flow4 + return 3, nil, false + }, + } + + // Run in parallel. + var taskResultSet, ok = Parallel(tasks...) + assert.False(t, ok, "ok should be false since we aborted task #2.") + + // Verify task #3. + // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send. + waitTimeout(t, taskResultSet.chz[3], "Task #3") + + // Now let the last task (#3) complete after abort. + flow4 <- <-flow3 + + // Verify task #0, #1, #2. + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, 2, nil, nil) +} + +func TestParallelRecover(t *testing.T) { + + // Create tasks. + var tasks = []Task{ + func(i int) (res interface{}, err error, abort bool) { + return 0, nil, false + }, + func(i int) (res interface{}, err error, abort bool) { + return 1, errors.New("some error"), false + }, + func(i int) (res interface{}, err error, abort bool) { + panic(2) + }, + } + + // Run in parallel. + var taskResultSet, ok = Parallel(tasks...) + assert.False(t, ok, "ok should be false since we panic'd in task #2.") + + // Verify task #0, #1, #2. + checkResult(t, taskResultSet, 0, 0, nil, nil) + checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) + checkResult(t, taskResultSet, 2, nil, nil, 2) +} + +// Wait for result +func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { + taskResult, ok := taskResultSet.LatestResult(index) + taskName := fmt.Sprintf("Task #%v", index) + assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) + assert.Equal(t, val, taskResult.Value, taskName) + if err != nil { + assert.Equal(t, err, taskResult.Error, taskName) + } else if pnk != nil { + assert.Equal(t, pnk, taskResult.Error.(Error).Cause(), taskName) + } else { + assert.Nil(t, taskResult.Error, taskName) + } +} + +// Wait for timeout (no result) +func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) { + select { + case _, ok := <-taskResultCh: + if !ok { + assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName) + } else { + assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName) + } + case <-time.After(1 * time.Second): // TODO use deterministic time? + // Good! + } +} diff --git a/common/bit_array.go b/common/bit_array.go index a3a87ccab..ea6a6ee1f 100644 --- a/common/bit_array.go +++ b/common/bit_array.go @@ -168,9 +168,8 @@ func (bA *BitArray) Sub(o *BitArray) *BitArray { } } return c - } else { - return bA.and(o.Not()) // Note degenerate case where o == nil } + return bA.and(o.Not()) // Note degenerate case where o == nil } func (bA *BitArray) IsEmpty() bool { diff --git a/common/bytes.go b/common/bytes.go index ba81bbe97..711720aa7 100644 --- a/common/bytes.go +++ b/common/bytes.go @@ -51,3 +51,12 @@ func (bz HexBytes) Bytes() []byte { func (bz HexBytes) String() string { return strings.ToUpper(hex.EncodeToString(bz)) } + +func (bz HexBytes) Format(s fmt.State, verb rune) { + switch verb { + case 'p': + s.Write([]byte(fmt.Sprintf("%p", bz))) + default: + s.Write([]byte(fmt.Sprintf("%X", []byte(bz)))) + } +} diff --git a/common/colors.go b/common/colors.go index 776b22e2e..85e592248 100644 --- a/common/colors.go +++ b/common/colors.go @@ -38,9 +38,8 @@ const ( func treat(s string, color string) string { if len(s) > 2 && s[:2] == "\x1b[" { return s - } else { - return color + s + ANSIReset } + return color + s + ANSIReset } func treatAll(color string, args ...interface{}) string { diff --git a/common/errors.go b/common/errors.go index 4710b9ee0..1ee1fb349 100644 --- a/common/errors.go +++ b/common/errors.go @@ -2,23 +2,239 @@ package common import ( "fmt" + "runtime" ) -type StackError struct { - Err interface{} - Stack []byte +//---------------------------------------- +// Convenience methods + +// ErrorWrap will just call .TraceFrom(), or create a new *cmnError. +func ErrorWrap(cause interface{}, format string, args ...interface{}) Error { + msg := Fmt(format, args...) + if causeCmnError, ok := cause.(*cmnError); ok { + return causeCmnError.TraceFrom(1, msg) + } + // NOTE: cause may be nil. + // NOTE: do not use causeCmnError here, not the same as nil. + return newError(msg, cause, cause).Stacktrace() +} + +//---------------------------------------- +// Error & cmnError + +/* +Usage: + +```go + // Error construction + var someT = errors.New("Some err type") + var err1 error = NewErrorWithT(someT, "my message") + ... + // Wrapping + var err2 error = ErrorWrap(err1, "another message") + if (err1 != err2) { panic("should be the same") + ... + // Error handling + switch err2.T() { + case someT: ... + default: ... + } +``` + +*/ +type Error interface { + Error() string + Message() string + Stacktrace() Error + Trace(format string, args ...interface{}) Error + TraceFrom(offset int, format string, args ...interface{}) Error + Cause() interface{} + WithT(t interface{}) Error + T() interface{} + Format(s fmt.State, verb rune) +} + +// New Error with no cause where the type is the format string of the message.. +func NewError(format string, args ...interface{}) Error { + msg := Fmt(format, args...) + return newError(msg, nil, format) + +} + +// New Error with specified type and message. +func NewErrorWithT(t interface{}, format string, args ...interface{}) Error { + msg := Fmt(format, args...) + return newError(msg, nil, t) +} + +// NOTE: The name of a function "NewErrorWithCause()" implies that you are +// creating a new Error, yet, if the cause is an Error, creating a new Error to +// hold a ref to the old Error is probably *not* what you want to do. +// So, use ErrorWrap(cause, format, a...) instead, which returns the same error +// if cause is an Error. +// IF you must set an Error as the cause of an Error, +// then you can use the WithCauser interface to do so manually. +// e.g. (error).(tmlibs.WithCauser).WithCause(causeError) + +type WithCauser interface { + WithCause(cause interface{}) Error +} + +type cmnError struct { + msg string // first msg which also appears in msg + cause interface{} // underlying cause (or panic object) + t interface{} // for switching on error + msgtraces []msgtraceItem // all messages traced + stacktrace []uintptr // first stack trace +} + +var _ WithCauser = &cmnError{} +var _ Error = &cmnError{} + +// NOTE: do not expose. +func newError(msg string, cause interface{}, t interface{}) *cmnError { + return &cmnError{ + msg: msg, + cause: cause, + t: t, + msgtraces: nil, + stacktrace: nil, + } +} + +func (err *cmnError) Message() string { + return err.msg +} + +func (err *cmnError) Error() string { + return fmt.Sprintf("%v", err) } -func (se StackError) String() string { - return fmt.Sprintf("Error: %v\nStack: %s", se.Err, se.Stack) +// Captures a stacktrace if one was not already captured. +func (err *cmnError) Stacktrace() Error { + if err.stacktrace == nil { + var offset = 3 + var depth = 32 + err.stacktrace = captureStacktrace(offset, depth) + } + return err } -func (se StackError) Error() string { - return se.String() +// Add tracing information with msg. +func (err *cmnError) Trace(format string, args ...interface{}) Error { + msg := Fmt(format, args...) + return err.doTrace(msg, 0) } -//-------------------------------------------------------------------------------------------------- -// panic wrappers +// Same as Trace, but traces the line `offset` calls out. +// If n == 0, the behavior is identical to Trace(). +func (err *cmnError) TraceFrom(offset int, format string, args ...interface{}) Error { + msg := Fmt(format, args...) + return err.doTrace(msg, offset) +} + +// Return last known cause. +// NOTE: The meaning of "cause" is left for the caller to define. +// There exists no "canonical" definition of "cause". +// Instead of blaming, try to handle it, or organize it. +func (err *cmnError) Cause() interface{} { + return err.cause +} + +// Overwrites the Error's cause. +func (err *cmnError) WithCause(cause interface{}) Error { + err.cause = cause + return err +} + +// Overwrites the Error's type. +func (err *cmnError) WithT(t interface{}) Error { + err.t = t + return err +} + +// Return the "type" of this message, primarily for switching +// to handle this Error. +func (err *cmnError) T() interface{} { + return err.t +} + +func (err *cmnError) doTrace(msg string, n int) Error { + pc, _, _, _ := runtime.Caller(n + 2) // +1 for doTrace(). +1 for the caller. + // Include file & line number & msg. + // Do not include the whole stack trace. + err.msgtraces = append(err.msgtraces, msgtraceItem{ + pc: pc, + msg: msg, + }) + return err +} + +func (err *cmnError) Format(s fmt.State, verb rune) { + switch verb { + case 'p': + s.Write([]byte(fmt.Sprintf("%p", &err))) + default: + if s.Flag('#') { + s.Write([]byte("--= Error =--\n")) + // Write msg. + s.Write([]byte(fmt.Sprintf("Message: %#s\n", err.msg))) + // Write cause. + s.Write([]byte(fmt.Sprintf("Cause: %#v\n", err.cause))) + // Write type. + s.Write([]byte(fmt.Sprintf("T: %#v\n", err.t))) + // Write msg trace items. + s.Write([]byte(fmt.Sprintf("Msg Traces:\n"))) + for i, msgtrace := range err.msgtraces { + s.Write([]byte(fmt.Sprintf(" %4d %s\n", i, msgtrace.String()))) + } + // Write stack trace. + if err.stacktrace != nil { + s.Write([]byte(fmt.Sprintf("Stack Trace:\n"))) + for i, pc := range err.stacktrace { + fnc := runtime.FuncForPC(pc) + file, line := fnc.FileLine(pc) + s.Write([]byte(fmt.Sprintf(" %4d %s:%d\n", i, file, line))) + } + } + s.Write([]byte("--= /Error =--\n")) + } else { + // Write msg. + if err.cause != nil { + s.Write([]byte(fmt.Sprintf("Error{`%s` (cause: %v)}", err.msg, err.cause))) // TODO tick-esc? + } else { + s.Write([]byte(fmt.Sprintf("Error{`%s`}", err.msg))) // TODO tick-esc? + } + } + } +} + +//---------------------------------------- +// stacktrace & msgtraceItem + +func captureStacktrace(offset int, depth int) []uintptr { + var pcs = make([]uintptr, depth) + n := runtime.Callers(offset, pcs) + return pcs[0:n] +} + +type msgtraceItem struct { + pc uintptr + msg string +} + +func (mti msgtraceItem) String() string { + fnc := runtime.FuncForPC(mti.pc) + file, line := fnc.FileLine(mti.pc) + return fmt.Sprintf("%s:%d - %s", + file, line, + mti.msg, + ) +} + +//---------------------------------------- +// Panic wrappers +// XXX DEPRECATED // A panic resulting from a sanity check means there is a programmer error // and some guarantee is not satisfied. diff --git a/common/errors_test.go b/common/errors_test.go new file mode 100644 index 000000000..2c5234f9f --- /dev/null +++ b/common/errors_test.go @@ -0,0 +1,107 @@ +package common + +import ( + fmt "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorPanic(t *testing.T) { + type pnk struct { + msg string + } + + capturePanic := func() (err Error) { + defer func() { + if r := recover(); r != nil { + err = ErrorWrap(r, "This is the message in ErrorWrap(r, message).") + } + return + }() + panic(pnk{"something"}) + return nil + } + + var err = capturePanic() + + assert.Equal(t, pnk{"something"}, err.Cause()) + assert.Equal(t, pnk{"something"}, err.T()) + assert.Equal(t, "This is the message in ErrorWrap(r, message).", err.Message()) + assert.Equal(t, "Error{`This is the message in ErrorWrap(r, message).` (cause: {something})}", fmt.Sprintf("%v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "Message: This is the message in ErrorWrap(r, message).") + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorWrapSomething(t *testing.T) { + + var err = ErrorWrap("something", "formatter%v%v", 0, 1) + + assert.Equal(t, "something", err.Cause()) + assert.Equal(t, "something", err.T()) + assert.Equal(t, "formatter01", err.Message()) + assert.Equal(t, "Error{`formatter01` (cause: something)}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorWrapNothing(t *testing.T) { + + var err = ErrorWrap(nil, "formatter%v%v", 0, 1) + + assert.Equal(t, nil, err.Cause()) + assert.Equal(t, nil, err.T()) + assert.Equal(t, "formatter01", err.Message()) + assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorNewError(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1) + + assert.Equal(t, nil, err.Cause()) + assert.Equal(t, "formatter%v%v", err.T()) + assert.Equal(t, "formatter01", err.Message()) + assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) + assert.NotContains(t, fmt.Sprintf("%#v", err), "Stack Trace") +} + +func TestErrorNewErrorWithStacktrace(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1).Stacktrace() + + assert.Equal(t, nil, err.Cause()) + assert.Equal(t, "formatter%v%v", err.T()) + assert.Equal(t, "formatter01", err.Message()) + assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) + assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") +} + +func TestErrorNewErrorWithTrace(t *testing.T) { + + var err = NewError("formatter%v%v", 0, 1) + err.Trace("trace %v", 1) + err.Trace("trace %v", 2) + err.Trace("trace %v", 3) + + assert.Equal(t, nil, err.Cause()) + assert.Equal(t, "formatter%v%v", err.T()) + assert.Equal(t, "formatter01", err.Message()) + assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) + assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) + dump := fmt.Sprintf("%#v", err) + assert.NotContains(t, dump, "Stack Trace") + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 1`, dump) + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 2`, dump) + assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 3`, dump) +} + +func TestErrorWrapError(t *testing.T) { + var err1 error = NewError("my message") + var err2 error = ErrorWrap(err1, "another message") + assert.Equal(t, err1, err2) +} diff --git a/common/io.go b/common/io.go index 378c19fc6..fa0443e09 100644 --- a/common/io.go +++ b/common/io.go @@ -20,9 +20,8 @@ func (pr *PrefixedReader) Read(p []byte) (n int, err error) { read := copy(p, pr.Prefix) pr.Prefix = pr.Prefix[read:] return read, nil - } else { - return pr.reader.Read(p) } + return pr.reader.Read(p) } // NOTE: Not goroutine safe diff --git a/common/nil.go b/common/nil.go new file mode 100644 index 000000000..31f75f008 --- /dev/null +++ b/common/nil.go @@ -0,0 +1,29 @@ +package common + +import "reflect" + +// Go lacks a simple and safe way to see if something is a typed nil. +// See: +// - https://dave.cheney.net/2017/08/09/typed-nils-in-go-2 +// - https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I/discussion +// - https://github.com/golang/go/issues/21538 +func IsTypedNil(o interface{}) bool { + rv := reflect.ValueOf(o) + switch rv.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice: + return rv.IsNil() + default: + return false + } +} + +// Returns true if it has zero length. +func IsEmpty(o interface{}) bool { + rv := reflect.ValueOf(o) + switch rv.Kind() { + case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: + return rv.Len() == 0 + default: + return false + } +} diff --git a/common/os.go b/common/os.go index f1e07115c..00f4da57b 100644 --- a/common/os.go +++ b/common/os.go @@ -148,6 +148,9 @@ func WriteFileAtomic(filename string, data []byte, perm os.FileMode) error { } else if n < len(data) { return io.ErrShortWrite } + // Close the file before renaming it, otherwise it will cause "The process + // cannot access the file because it is being used by another process." on windows. + f.Close() return os.Rename(f.Name(), filename) } @@ -183,11 +186,10 @@ func Prompt(prompt string, defaultValue string) (string, error) { line, err := reader.ReadString('\n') if err != nil { return defaultValue, err - } else { - line = strings.TrimSpace(line) - if line == "" { - return defaultValue, nil - } - return line, nil } + line = strings.TrimSpace(line) + if line == "" { + return defaultValue, nil + } + return line, nil } diff --git a/common/random.go b/common/random.go index ca71b6143..389a32fc2 100644 --- a/common/random.go +++ b/common/random.go @@ -13,34 +13,150 @@ const ( // pseudo random number generator. // seeded with OS randomness (crand) -var prng struct { + +type Rand struct { sync.Mutex - *mrand.Rand + rand *mrand.Rand +} + +var grand *Rand + +func init() { + grand = NewRand() + grand.init() +} + +func NewRand() *Rand { + rand := &Rand{} + rand.init() + return rand } -func reset() { - b := cRandBytes(8) +func (r *Rand) init() { + bz := cRandBytes(8) var seed uint64 for i := 0; i < 8; i++ { - seed |= uint64(b[i]) + seed |= uint64(bz[i]) seed <<= 8 } - prng.Lock() - prng.Rand = mrand.New(mrand.NewSource(int64(seed))) - prng.Unlock() + r.reset(int64(seed)) } -func init() { - reset() +func (r *Rand) reset(seed int64) { + r.rand = mrand.New(mrand.NewSource(seed)) +} + +//---------------------------------------- +// Global functions + +func Seed(seed int64) { + grand.Seed(seed) +} + +func RandStr(length int) string { + return grand.Str(length) +} + +func RandUint16() uint16 { + return grand.Uint16() +} + +func RandUint32() uint32 { + return grand.Uint32() +} + +func RandUint64() uint64 { + return grand.Uint64() +} + +func RandUint() uint { + return grand.Uint() +} + +func RandInt16() int16 { + return grand.Int16() +} + +func RandInt32() int32 { + return grand.Int32() +} + +func RandInt64() int64 { + return grand.Int64() +} + +func RandInt() int { + return grand.Int() +} + +func RandInt31() int32 { + return grand.Int31() +} + +func RandInt31n(n int32) int32 { + return grand.Int31n(n) +} + +func RandInt63() int64 { + return grand.Int63() +} + +func RandInt63n(n int64) int64 { + return grand.Int63n(n) +} + +func RandUint16Exp() uint16 { + return grand.Uint16Exp() +} + +func RandUint32Exp() uint32 { + return grand.Uint32Exp() +} + +func RandUint64Exp() uint64 { + return grand.Uint64Exp() +} + +func RandFloat32() float32 { + return grand.Float32() +} + +func RandFloat64() float64 { + return grand.Float64() +} + +func RandTime() time.Time { + return grand.Time() +} + +func RandBytes(n int) []byte { + return grand.Bytes(n) +} + +func RandIntn(n int) int { + return grand.Intn(n) +} + +func RandPerm(n int) []int { + return grand.Perm(n) +} + +//---------------------------------------- +// Rand methods + +func (r *Rand) Seed(seed int64) { + r.Lock() + r.reset(seed) + r.Unlock() } // Constructs an alphanumeric string of given length. // It is not safe for cryptographic usage. -func RandStr(length int) string { +func (r *Rand) Str(length int) string { chars := []byte{} MAIN_LOOP: for { - val := RandInt63() + val := r.Int63() for i := 0; i < 10; i++ { v := int(val & 0x3f) // rightmost 6 bits if v >= 62 { // only 62 characters in strChars @@ -60,127 +176,151 @@ MAIN_LOOP: } // It is not safe for cryptographic usage. -func RandUint16() uint16 { - return uint16(RandUint32() & (1<<16 - 1)) +func (r *Rand) Uint16() uint16 { + return uint16(r.Uint32() & (1<<16 - 1)) } // It is not safe for cryptographic usage. -func RandUint32() uint32 { - prng.Lock() - u32 := prng.Uint32() - prng.Unlock() +func (r *Rand) Uint32() uint32 { + r.Lock() + u32 := r.rand.Uint32() + r.Unlock() return u32 } // It is not safe for cryptographic usage. -func RandUint64() uint64 { - return uint64(RandUint32())<<32 + uint64(RandUint32()) +func (r *Rand) Uint64() uint64 { + return uint64(r.Uint32())<<32 + uint64(r.Uint32()) } // It is not safe for cryptographic usage. -func RandUint() uint { - prng.Lock() - i := prng.Int() - prng.Unlock() +func (r *Rand) Uint() uint { + r.Lock() + i := r.rand.Int() + r.Unlock() return uint(i) } // It is not safe for cryptographic usage. -func RandInt16() int16 { - return int16(RandUint32() & (1<<16 - 1)) +func (r *Rand) Int16() int16 { + return int16(r.Uint32() & (1<<16 - 1)) } // It is not safe for cryptographic usage. -func RandInt32() int32 { - return int32(RandUint32()) +func (r *Rand) Int32() int32 { + return int32(r.Uint32()) } // It is not safe for cryptographic usage. -func RandInt64() int64 { - return int64(RandUint64()) +func (r *Rand) Int64() int64 { + return int64(r.Uint64()) } // It is not safe for cryptographic usage. -func RandInt() int { - prng.Lock() - i := prng.Int() - prng.Unlock() +func (r *Rand) Int() int { + r.Lock() + i := r.rand.Int() + r.Unlock() return i } // It is not safe for cryptographic usage. -func RandInt31() int32 { - prng.Lock() - i31 := prng.Int31() - prng.Unlock() +func (r *Rand) Int31() int32 { + r.Lock() + i31 := r.rand.Int31() + r.Unlock() return i31 } // It is not safe for cryptographic usage. -func RandInt63() int64 { - prng.Lock() - i63 := prng.Int63() - prng.Unlock() +func (r *Rand) Int31n(n int32) int32 { + r.Lock() + i31n := r.rand.Int31n(n) + r.Unlock() + return i31n +} + +// It is not safe for cryptographic usage. +func (r *Rand) Int63() int64 { + r.Lock() + i63 := r.rand.Int63() + r.Unlock() return i63 } +// It is not safe for cryptographic usage. +func (r *Rand) Int63n(n int64) int64 { + r.Lock() + i63n := r.rand.Int63n(n) + r.Unlock() + return i63n +} + // Distributed pseudo-exponentially to test for various cases // It is not safe for cryptographic usage. -func RandUint16Exp() uint16 { - bits := RandUint32() % 16 +func (r *Rand) Uint16Exp() uint16 { + bits := r.Uint32() % 16 if bits == 0 { return 0 } n := uint16(1 << (bits - 1)) - n += uint16(RandInt31()) & ((1 << (bits - 1)) - 1) + n += uint16(r.Int31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases // It is not safe for cryptographic usage. -func RandUint32Exp() uint32 { - bits := RandUint32() % 32 +func (r *Rand) Uint32Exp() uint32 { + bits := r.Uint32() % 32 if bits == 0 { return 0 } n := uint32(1 << (bits - 1)) - n += uint32(RandInt31()) & ((1 << (bits - 1)) - 1) + n += uint32(r.Int31()) & ((1 << (bits - 1)) - 1) return n } // Distributed pseudo-exponentially to test for various cases // It is not safe for cryptographic usage. -func RandUint64Exp() uint64 { - bits := RandUint32() % 64 +func (r *Rand) Uint64Exp() uint64 { + bits := r.Uint32() % 64 if bits == 0 { return 0 } n := uint64(1 << (bits - 1)) - n += uint64(RandInt63()) & ((1 << (bits - 1)) - 1) + n += uint64(r.Int63()) & ((1 << (bits - 1)) - 1) return n } // It is not safe for cryptographic usage. -func RandFloat32() float32 { - prng.Lock() - f32 := prng.Float32() - prng.Unlock() +func (r *Rand) Float32() float32 { + r.Lock() + f32 := r.rand.Float32() + r.Unlock() return f32 } // It is not safe for cryptographic usage. -func RandTime() time.Time { - return time.Unix(int64(RandUint64Exp()), 0) +func (r *Rand) Float64() float64 { + r.Lock() + f64 := r.rand.Float64() + r.Unlock() + return f64 +} + +// It is not safe for cryptographic usage. +func (r *Rand) Time() time.Time { + return time.Unix(int64(r.Uint64Exp()), 0) } // RandBytes returns n random bytes from the OS's source of entropy ie. via crypto/rand. // It is not safe for cryptographic usage. -func RandBytes(n int) []byte { +func (r *Rand) Bytes(n int) []byte { // cRandBytes isn't guaranteed to be fast so instead // use random bytes generated from the internal PRNG bs := make([]byte, n) for i := 0; i < len(bs); i++ { - bs[i] = byte(RandInt() & 0xFF) + bs[i] = byte(r.Int() & 0xFF) } return bs } @@ -188,19 +328,19 @@ func RandBytes(n int) []byte { // RandIntn returns, as an int, a non-negative pseudo-random number in [0, n). // It panics if n <= 0. // It is not safe for cryptographic usage. -func RandIntn(n int) int { - prng.Lock() - i := prng.Intn(n) - prng.Unlock() +func (r *Rand) Intn(n int) int { + r.Lock() + i := r.rand.Intn(n) + r.Unlock() return i } // RandPerm returns a pseudo-random permutation of n integers in [0, n). // It is not safe for cryptographic usage. -func RandPerm(n int) []int { - prng.Lock() - perm := prng.Perm(n) - prng.Unlock() +func (r *Rand) Perm(n int) []int { + r.Lock() + perm := r.rand.Perm(n) + r.Unlock() return perm } diff --git a/common/random_test.go b/common/random_test.go index 216f2f8bc..b58b4a13a 100644 --- a/common/random_test.go +++ b/common/random_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "io" mrand "math/rand" "sync" "testing" @@ -33,37 +32,38 @@ func TestRandIntn(t *testing.T) { } } -// It is essential that these tests run and never repeat their outputs -// lest we've been pwned and the behavior of our randomness is controlled. -// See Issues: -// * https://github.com/tendermint/tmlibs/issues/99 -// * https://github.com/tendermint/tendermint/issues/973 -func TestUniqueRng(t *testing.T) { - buf := new(bytes.Buffer) - outputs := make(map[string][]int) +// Test to make sure that we never call math.rand(). +// We do this by ensuring that outputs are deterministic. +func TestDeterminism(t *testing.T) { + var firstOutput string + + // Set math/rand's seed for the sake of debugging this test. + // (It isn't strictly necessary). + mrand.Seed(1) + for i := 0; i < 100; i++ { - testThemAll(buf) - output := buf.String() - buf.Reset() - runs, seen := outputs[output] - if seen { - t.Errorf("Run #%d's output was already seen in previous runs: %v", i, runs) + output := testThemAll() + if i == 0 { + firstOutput = output + } else { + if firstOutput != output { + t.Errorf("Run #%d's output was different from first run.\nfirst: %v\nlast: %v", + i, firstOutput, output) + } } - outputs[output] = append(outputs[output], i) } } -func testThemAll(out io.Writer) { - // Reset the internal PRNG - reset() +func testThemAll() string { - // Set math/rand's Seed so that any direct invocations - // of math/rand will reveal themselves. - mrand.Seed(1) + // Such determinism. + grand.reset(1) + + // Use it. + out := new(bytes.Buffer) perm := RandPerm(10) blob, _ := json.Marshal(perm) fmt.Fprintf(out, "perm: %s\n", blob) - fmt.Fprintf(out, "randInt: %d\n", RandInt()) fmt.Fprintf(out, "randUint: %d\n", RandUint()) fmt.Fprintf(out, "randIntn: %d\n", RandIntn(97)) @@ -76,6 +76,7 @@ func testThemAll(out io.Writer) { fmt.Fprintf(out, "randUint16Exp: %d\n", RandUint16Exp()) fmt.Fprintf(out, "randUint32Exp: %d\n", RandUint32Exp()) fmt.Fprintf(out, "randUint64Exp: %d\n", RandUint64Exp()) + return out.String() } func TestRngConcurrencySafety(t *testing.T) { diff --git a/common/service.go b/common/service.go index 2502d671c..2f90fa4f9 100644 --- a/common/service.go +++ b/common/service.go @@ -125,9 +125,8 @@ func (bs *BaseService) Start() error { if atomic.LoadUint32(&bs.stopped) == 1 { bs.Logger.Error(Fmt("Not starting %v -- already stopped", bs.name), "impl", bs.impl) return ErrAlreadyStopped - } else { - bs.Logger.Info(Fmt("Starting %v", bs.name), "impl", bs.impl) } + bs.Logger.Info(Fmt("Starting %v", bs.name), "impl", bs.impl) err := bs.impl.OnStart() if err != nil { // revert flag @@ -135,10 +134,9 @@ func (bs *BaseService) Start() error { return err } return nil - } else { - bs.Logger.Debug(Fmt("Not starting %v -- already started", bs.name), "impl", bs.impl) - return ErrAlreadyStarted } + bs.Logger.Debug(Fmt("Not starting %v -- already started", bs.name), "impl", bs.impl) + return ErrAlreadyStarted } // OnStart implements Service by doing nothing. @@ -154,10 +152,9 @@ func (bs *BaseService) Stop() error { bs.impl.OnStop() close(bs.quit) return nil - } else { - bs.Logger.Debug(Fmt("Stopping %v (ignoring: already stopped)", bs.name), "impl", bs.impl) - return ErrAlreadyStopped } + bs.Logger.Debug(Fmt("Stopping %v (ignoring: already stopped)", bs.name), "impl", bs.impl) + return ErrAlreadyStopped } // OnStop implements Service by doing nothing. diff --git a/common/string.go b/common/string.go index a6895eb25..0e2231e91 100644 --- a/common/string.go +++ b/common/string.go @@ -6,25 +6,12 @@ import ( "strings" ) -// Fmt shorthand, XXX DEPRECATED -var Fmt = fmt.Sprintf - -// RightPadString adds spaces to the right of a string to make it length totalLength -func RightPadString(s string, totalLength int) string { - remaining := totalLength - len(s) - if remaining > 0 { - s = s + strings.Repeat(" ", remaining) - } - return s -} - -// LeftPadString adds spaces to the left of a string to make it length totalLength -func LeftPadString(s string, totalLength int) string { - remaining := totalLength - len(s) - if remaining > 0 { - s = strings.Repeat(" ", remaining) + s +// Like fmt.Sprintf, but skips formatting if args are empty. +var Fmt = func(format string, a ...interface{}) string { + if len(a) == 0 { + return format } - return s + return fmt.Sprintf(format, a...) } // IsHex returns true for non-empty hex-string prefixed with "0x" @@ -53,3 +40,20 @@ func StringInSlice(a string, list []string) bool { } return false } + +// SplitAndTrim slices s into all subslices separated by sep and returns a +// slice of the string s with all leading and trailing Unicode code points +// contained in cutset removed. If sep is empty, SplitAndTrim splits after each +// UTF-8 sequence. First part is equivalent to strings.SplitN with a count of +// -1. +func SplitAndTrim(s, sep, cutset string) []string { + if s == "" { + return []string{} + } + + spl := strings.Split(s, sep) + for i := 0; i < len(spl); i++ { + spl[i] = strings.Trim(spl[i], cutset) + } + return spl +} diff --git a/common/string_test.go b/common/string_test.go index b8a917c16..82ba67844 100644 --- a/common/string_test.go +++ b/common/string_test.go @@ -30,3 +30,22 @@ func TestIsHex(t *testing.T) { assert.True(t, IsHex(v), "%q is hex", v) } } + +func TestSplitAndTrim(t *testing.T) { + testCases := []struct { + s string + sep string + cutset string + expected []string + }{ + {"a,b,c", ",", " ", []string{"a", "b", "c"}}, + {" a , b , c ", ",", " ", []string{"a", "b", "c"}}, + {" a, b, c ", ",", " ", []string{"a", "b", "c"}}, + {" , ", ",", " ", []string{"", ""}}, + {" ", ",", " ", []string{""}}, + } + + for _, tc := range testCases { + assert.Equal(t, tc.expected, SplitAndTrim(tc.s, tc.sep, tc.cutset), "%s", tc.s) + } +} diff --git a/common/types.pb.go b/common/types.pb.go index c301d28c0..047b7aee2 100644 --- a/common/types.pb.go +++ b/common/types.pb.go @@ -1,6 +1,5 @@ -// Code generated by protoc-gen-gogo. +// Code generated by protoc-gen-gogo. DO NOT EDIT. // source: common/types.proto -// DO NOT EDIT! /* Package common is a generated protocol buffer package. diff --git a/common/word.go b/common/word.go index 4072482b8..a5b841f55 100644 --- a/common/word.go +++ b/common/word.go @@ -72,9 +72,8 @@ func (tuple Tuple256) Compare(other Tuple256) int { firstCompare := tuple.First.Compare(other.First) if firstCompare == 0 { return tuple.Second.Compare(other.Second) - } else { - return firstCompare } + return firstCompare } func Tuple256Split(t Tuple256) (Word256, Word256) { diff --git a/db/backend_test.go b/db/backend_test.go index 80fbbb140..c407b214f 100644 --- a/db/backend_test.go +++ b/db/backend_test.go @@ -47,7 +47,7 @@ func testBackendGetSetDelete(t *testing.T, backend DBBackendType) { } func TestBackendsGetSetDelete(t *testing.T) { - for dbType, _ := range backends { + for dbType := range backends { testBackendGetSetDelete(t, dbType) } } diff --git a/db/c_level_db.go b/db/c_level_db.go index a59137883..e3e6c1d5d 100644 --- a/db/c_level_db.go +++ b/db/c_level_db.go @@ -171,6 +171,14 @@ func (mBatch *cLevelDBBatch) Write() { } } +// Implements Batch. +func (mBatch *cLevelDBBatch) WriteSync() { + err := mBatch.db.db.Write(mBatch.db.woSync, mBatch.batch) + if err != nil { + panic(err) + } +} + //---------------------------------------- // Iterator // NOTE This is almost identical to db/go_level_db.Iterator diff --git a/db/common_test.go b/db/common_test.go index 1b0f00416..1d8d52c5f 100644 --- a/db/common_test.go +++ b/db/common_test.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -9,6 +10,14 @@ import ( cmn "github.com/tendermint/tmlibs/common" ) +//---------------------------------------- +// Helper functions. + +func checkValue(t *testing.T, db DB, key []byte, valueWanted []byte) { + valueGot := db.Get(key) + assert.Equal(t, valueWanted, valueGot) +} + func checkValid(t *testing.T, itr Iterator, expected bool) { valid := itr.Valid() require.Equal(t, expected, valid) @@ -46,110 +55,131 @@ func checkValuePanics(t *testing.T, itr Iterator) { } func newTempDB(t *testing.T, backend DBBackendType) (db DB) { - dir, dirname := cmn.Tempdir("test_go_iterator") + dir, dirname := cmn.Tempdir("db_common_test") db = NewDB("testdb", backend, dirname) dir.Close() return db } -func TestDBIteratorSingleKey(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) - db.SetSync(bz("1"), bz("value_1")) - itr := db.Iterator(nil, nil) +//---------------------------------------- +// mockDB - checkValid(t, itr, true) - checkNext(t, itr, false) - checkValid(t, itr, false) - checkNextPanics(t, itr) +// NOTE: not actually goroutine safe. +// If you want something goroutine safe, maybe you just want a MemDB. +type mockDB struct { + mtx sync.Mutex + calls map[string]int +} - // Once invalid... - checkInvalid(t, itr) - }) +func newMockDB() *mockDB { + return &mockDB{ + calls: make(map[string]int), } } -func TestDBIteratorTwoKeys(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) - db.SetSync(bz("1"), bz("value_1")) - db.SetSync(bz("2"), bz("value_1")) +func (mdb *mockDB) Mutex() *sync.Mutex { + return &(mdb.mtx) +} - { // Fail by calling Next too much - itr := db.Iterator(nil, nil) - checkValid(t, itr, true) +func (mdb *mockDB) Get([]byte) []byte { + mdb.calls["Get"]++ + return nil +} - checkNext(t, itr, true) - checkValid(t, itr, true) +func (mdb *mockDB) Has([]byte) bool { + mdb.calls["Has"]++ + return false +} - checkNext(t, itr, false) - checkValid(t, itr, false) +func (mdb *mockDB) Set([]byte, []byte) { + mdb.calls["Set"]++ +} - checkNextPanics(t, itr) +func (mdb *mockDB) SetSync([]byte, []byte) { + mdb.calls["SetSync"]++ +} - // Once invalid... - checkInvalid(t, itr) - } - }) - } +func (mdb *mockDB) SetNoLock([]byte, []byte) { + mdb.calls["SetNoLock"]++ } -func TestDBIteratorMany(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) +func (mdb *mockDB) SetNoLockSync([]byte, []byte) { + mdb.calls["SetNoLockSync"]++ +} - keys := make([][]byte, 100) - for i := 0; i < 100; i++ { - keys[i] = []byte{byte(i)} - } +func (mdb *mockDB) Delete([]byte) { + mdb.calls["Delete"]++ +} - value := []byte{5} - for _, k := range keys { - db.Set(k, value) - } +func (mdb *mockDB) DeleteSync([]byte) { + mdb.calls["DeleteSync"]++ +} - itr := db.Iterator(nil, nil) - defer itr.Close() - for ; itr.Valid(); itr.Next() { - assert.Equal(t, db.Get(itr.Key()), itr.Value()) - } - }) - } +func (mdb *mockDB) DeleteNoLock([]byte) { + mdb.calls["DeleteNoLock"]++ } -func TestDBIteratorEmpty(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) - itr := db.Iterator(nil, nil) +func (mdb *mockDB) DeleteNoLockSync([]byte) { + mdb.calls["DeleteNoLockSync"]++ +} - checkInvalid(t, itr) - }) - } +func (mdb *mockDB) Iterator(start, end []byte) Iterator { + mdb.calls["Iterator"]++ + return &mockIterator{} } -func TestDBIteratorEmptyBeginAfter(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) - itr := db.Iterator(bz("1"), nil) +func (mdb *mockDB) ReverseIterator(start, end []byte) Iterator { + mdb.calls["ReverseIterator"]++ + return &mockIterator{} +} - checkInvalid(t, itr) - }) - } +func (mdb *mockDB) Close() { + mdb.calls["Close"]++ } -func TestDBIteratorNonemptyBeginAfter(t *testing.T) { - for backend, _ := range backends { - t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { - db := newTempDB(t, backend) - db.SetSync(bz("1"), bz("value_1")) - itr := db.Iterator(bz("2"), nil) +func (mdb *mockDB) NewBatch() Batch { + mdb.calls["NewBatch"]++ + return &memBatch{db: mdb} +} + +func (mdb *mockDB) Print() { + mdb.calls["Print"]++ + fmt.Printf("mockDB{%v}", mdb.Stats()) +} - checkInvalid(t, itr) - }) +func (mdb *mockDB) Stats() map[string]string { + mdb.calls["Stats"]++ + + res := make(map[string]string) + for key, count := range mdb.calls { + res[key] = fmt.Sprintf("%d", count) } + return res +} + +//---------------------------------------- +// mockIterator + +type mockIterator struct{} + +func (mockIterator) Domain() (start []byte, end []byte) { + return nil, nil +} + +func (mockIterator) Valid() bool { + return false +} + +func (mockIterator) Next() { +} + +func (mockIterator) Key() []byte { + return nil +} + +func (mockIterator) Value() []byte { + return nil +} + +func (mockIterator) Close() { } diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 000000000..a56901016 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,194 @@ +package db + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDBIteratorSingleKey(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(nil, nil) + + checkValid(t, itr, true) + checkNext(t, itr, false) + checkValid(t, itr, false) + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorTwoKeys(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + db.SetSync(bz("2"), bz("value_1")) + + { // Fail by calling Next too much + itr := db.Iterator(nil, nil) + checkValid(t, itr, true) + + checkNext(t, itr, true) + checkValid(t, itr, true) + + checkNext(t, itr, false) + checkValid(t, itr, false) + + checkNextPanics(t, itr) + + // Once invalid... + checkInvalid(t, itr) + } + }) + } +} + +func TestDBIteratorMany(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + keys := make([][]byte, 100) + for i := 0; i < 100; i++ { + keys[i] = []byte{byte(i)} + } + + value := []byte{5} + for _, k := range keys { + db.Set(k, value) + } + + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + assert.Equal(t, db.Get(itr.Key()), itr.Value()) + } + }) + } +} + +func TestDBIteratorEmpty(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(nil, nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorEmptyBeginAfter(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + itr := db.Iterator(bz("1"), nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBIteratorNonemptyBeginAfter(t *testing.T) { + for backend := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + db.SetSync(bz("1"), bz("value_1")) + itr := db.Iterator(bz("2"), nil) + + checkInvalid(t, itr) + }) + } +} + +func TestDBBatchWrite1(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Delete(bz("3")) + batch.Set(bz("4"), bz("4")) + batch.Write() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWrite2(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Set(bz("4"), bz("4")) + batch.Delete(bz("3")) + batch.Write() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWriteSync1(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Delete(bz("3")) + batch.Set(bz("4"), bz("4")) + batch.WriteSync() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 2, mdb.calls["SetNoLock"]) + assert.Equal(t, 1, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) +} + +func TestDBBatchWriteSync2(t *testing.T) { + mdb := newMockDB() + ddb := NewDebugDB(t.Name(), mdb) + batch := ddb.NewBatch() + + batch.Set(bz("1"), bz("1")) + batch.Set(bz("2"), bz("2")) + batch.Set(bz("4"), bz("4")) + batch.Delete(bz("3")) + batch.WriteSync() + + assert.Equal(t, 0, mdb.calls["Set"]) + assert.Equal(t, 0, mdb.calls["SetSync"]) + assert.Equal(t, 3, mdb.calls["SetNoLock"]) + assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) + assert.Equal(t, 0, mdb.calls["Delete"]) + assert.Equal(t, 0, mdb.calls["DeleteSync"]) + assert.Equal(t, 0, mdb.calls["DeleteNoLock"]) + assert.Equal(t, 1, mdb.calls["DeleteNoLockSync"]) +} diff --git a/db/debug_db.go b/db/debug_db.go new file mode 100644 index 000000000..7a15bc294 --- /dev/null +++ b/db/debug_db.go @@ -0,0 +1,216 @@ +package db + +import ( + "fmt" + "sync" +) + +//---------------------------------------- +// debugDB + +type debugDB struct { + label string + db DB +} + +// For printing all operationgs to the console for debugging. +func NewDebugDB(label string, db DB) debugDB { + return debugDB{ + label: label, + db: db, + } +} + +// Implements atomicSetDeleter. +func (ddb debugDB) Mutex() *sync.Mutex { return nil } + +// Implements DB. +func (ddb debugDB) Get(key []byte) (value []byte) { + defer fmt.Printf("%v.Get(%X) %X\n", ddb.label, key, value) + value = ddb.db.Get(key) + return +} + +// Implements DB. +func (ddb debugDB) Has(key []byte) (has bool) { + defer fmt.Printf("%v.Has(%X) %v\n", ddb.label, key, has) + return ddb.db.Has(key) +} + +// Implements DB. +func (ddb debugDB) Set(key []byte, value []byte) { + fmt.Printf("%v.Set(%X, %X)\n", ddb.label, key, value) + ddb.db.Set(key, value) +} + +// Implements DB. +func (ddb debugDB) SetSync(key []byte, value []byte) { + fmt.Printf("%v.SetSync(%X, %X)\n", ddb.label, key, value) + ddb.db.SetSync(key, value) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) SetNoLock(key []byte, value []byte) { + fmt.Printf("%v.SetNoLock(%X, %X)\n", ddb.label, key, value) + ddb.db.Set(key, value) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) SetNoLockSync(key []byte, value []byte) { + fmt.Printf("%v.SetNoLockSync(%X, %X)\n", ddb.label, key, value) + ddb.db.SetSync(key, value) +} + +// Implements DB. +func (ddb debugDB) Delete(key []byte) { + fmt.Printf("%v.Delete(%X)\n", ddb.label, key) + ddb.db.Delete(key) +} + +// Implements DB. +func (ddb debugDB) DeleteSync(key []byte) { + fmt.Printf("%v.DeleteSync(%X)\n", ddb.label, key) + ddb.db.DeleteSync(key) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) DeleteNoLock(key []byte) { + fmt.Printf("%v.DeleteNoLock(%X)\n", ddb.label, key) + ddb.db.Delete(key) +} + +// Implements atomicSetDeleter. +func (ddb debugDB) DeleteNoLockSync(key []byte) { + fmt.Printf("%v.DeleteNoLockSync(%X)\n", ddb.label, key) + ddb.db.DeleteSync(key) +} + +// Implements DB. +func (ddb debugDB) Iterator(start, end []byte) Iterator { + fmt.Printf("%v.Iterator(%X, %X)\n", ddb.label, start, end) + return NewDebugIterator(ddb.label, ddb.db.Iterator(start, end)) +} + +// Implements DB. +func (ddb debugDB) ReverseIterator(start, end []byte) Iterator { + fmt.Printf("%v.ReverseIterator(%X, %X)\n", ddb.label, start, end) + return NewDebugIterator(ddb.label, ddb.db.ReverseIterator(start, end)) +} + +// Implements DB. +func (ddb debugDB) NewBatch() Batch { + fmt.Printf("%v.NewBatch()\n", ddb.label) + return NewDebugBatch(ddb.label, ddb.db.NewBatch()) +} + +// Implements DB. +func (ddb debugDB) Close() { + fmt.Printf("%v.Close()\n", ddb.label) + ddb.db.Close() +} + +// Implements DB. +func (ddb debugDB) Print() { + ddb.db.Print() +} + +// Implements DB. +func (ddb debugDB) Stats() map[string]string { + return ddb.db.Stats() +} + +//---------------------------------------- +// debugIterator + +type debugIterator struct { + label string + itr Iterator +} + +// For printing all operationgs to the console for debugging. +func NewDebugIterator(label string, itr Iterator) debugIterator { + return debugIterator{ + label: label, + itr: itr, + } +} + +// Implements Iterator. +func (ditr debugIterator) Domain() (start []byte, end []byte) { + defer fmt.Printf("%v.itr.Domain() (%X,%X)\n", ditr.label, start, end) + start, end = ditr.itr.Domain() + return +} + +// Implements Iterator. +func (ditr debugIterator) Valid() (ok bool) { + defer fmt.Printf("%v.itr.Valid() %v\n", ditr.label, ok) + ok = ditr.itr.Valid() + return +} + +// Implements Iterator. +func (ditr debugIterator) Next() { + fmt.Printf("%v.itr.Next()\n", ditr.label) + ditr.itr.Next() +} + +// Implements Iterator. +func (ditr debugIterator) Key() (key []byte) { + fmt.Printf("%v.itr.Key() %X\n", ditr.label, key) + key = ditr.itr.Key() + return +} + +// Implements Iterator. +func (ditr debugIterator) Value() (value []byte) { + fmt.Printf("%v.itr.Value() %X\n", ditr.label, value) + value = ditr.itr.Value() + return +} + +// Implements Iterator. +func (ditr debugIterator) Close() { + fmt.Printf("%v.itr.Close()\n", ditr.label) + ditr.itr.Close() +} + +//---------------------------------------- +// debugBatch + +type debugBatch struct { + label string + bch Batch +} + +// For printing all operationgs to the console for debugging. +func NewDebugBatch(label string, bch Batch) debugBatch { + return debugBatch{ + label: label, + bch: bch, + } +} + +// Implements Batch. +func (dbch debugBatch) Set(key, value []byte) { + fmt.Printf("%v.batch.Set(%X, %X)\n", dbch.label, key, value) + dbch.bch.Set(key, value) +} + +// Implements Batch. +func (dbch debugBatch) Delete(key []byte) { + fmt.Printf("%v.batch.Delete(%X)\n", dbch.label, key) + dbch.bch.Delete(key) +} + +// Implements Batch. +func (dbch debugBatch) Write() { + fmt.Printf("%v.batch.Write()\n", dbch.label) + dbch.bch.Write() +} + +// Implements Batch. +func (dbch debugBatch) WriteSync() { + fmt.Printf("%v.batch.WriteSync()\n", dbch.label) + dbch.bch.WriteSync() +} diff --git a/db/go_level_db.go b/db/go_level_db.go index 9fed329bf..9ff162e38 100644 --- a/db/go_level_db.go +++ b/db/go_level_db.go @@ -10,7 +10,7 @@ import ( "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) func init() { @@ -46,9 +46,8 @@ func (db *GoLevelDB) Get(key []byte) []byte { if err != nil { if err == errors.ErrNotFound { return nil - } else { - panic(err) } + panic(err) } return res } @@ -64,7 +63,7 @@ func (db *GoLevelDB) Set(key []byte, value []byte) { value = nonNilBytes(value) err := db.db.Put(key, value, nil) if err != nil { - PanicCrisis(err) + cmn.PanicCrisis(err) } } @@ -74,7 +73,7 @@ func (db *GoLevelDB) SetSync(key []byte, value []byte) { value = nonNilBytes(value) err := db.db.Put(key, value, &opt.WriteOptions{Sync: true}) if err != nil { - PanicCrisis(err) + cmn.PanicCrisis(err) } } @@ -83,7 +82,7 @@ func (db *GoLevelDB) Delete(key []byte) { key = nonNilBytes(key) err := db.db.Delete(key, nil) if err != nil { - PanicCrisis(err) + cmn.PanicCrisis(err) } } @@ -92,7 +91,7 @@ func (db *GoLevelDB) DeleteSync(key []byte) { key = nonNilBytes(key) err := db.db.Delete(key, &opt.WriteOptions{Sync: true}) if err != nil { - PanicCrisis(err) + cmn.PanicCrisis(err) } } @@ -110,10 +109,10 @@ func (db *GoLevelDB) Print() { str, _ := db.db.GetProperty("leveldb.stats") fmt.Printf("%v\n", str) - iter := db.db.NewIterator(nil, nil) - for iter.Next() { - key := iter.Key() - value := iter.Value() + itr := db.db.NewIterator(nil, nil) + for itr.Next() { + key := itr.Key() + value := itr.Value() fmt.Printf("[%X]:\t[%X]\n", key, value) } } @@ -167,7 +166,15 @@ func (mBatch *goLevelDBBatch) Delete(key []byte) { // Implements Batch. func (mBatch *goLevelDBBatch) Write() { - err := mBatch.db.db.Write(mBatch.batch, nil) + err := mBatch.db.db.Write(mBatch.batch, &opt.WriteOptions{Sync: false}) + if err != nil { + panic(err) + } +} + +// Implements Batch. +func (mBatch *goLevelDBBatch) WriteSync() { + err := mBatch.db.db.Write(mBatch.batch, &opt.WriteOptions{Sync: true}) if err != nil { panic(err) } diff --git a/db/go_level_db_test.go b/db/go_level_db_test.go index 88b6730f3..266add8b5 100644 --- a/db/go_level_db_test.go +++ b/db/go_level_db_test.go @@ -30,7 +30,7 @@ func BenchmarkRandomReadsWrites(b *testing.B) { // Write something { idx := (int64(cmn.RandInt()) % numItems) - internal[idx] += 1 + internal[idx]++ val := internal[idx] idxBytes := int642Bytes(int64(idx)) valBytes := int642Bytes(int64(val)) diff --git a/db/mem_batch.go b/db/mem_batch.go index 7072d931a..81a63d62b 100644 --- a/db/mem_batch.go +++ b/db/mem_batch.go @@ -5,7 +5,9 @@ import "sync" type atomicSetDeleter interface { Mutex() *sync.Mutex SetNoLock(key, value []byte) + SetNoLockSync(key, value []byte) DeleteNoLock(key []byte) + DeleteNoLockSync(key []byte) } type memBatch struct { @@ -35,16 +37,35 @@ func (mBatch *memBatch) Delete(key []byte) { } func (mBatch *memBatch) Write() { - mtx := mBatch.db.Mutex() - mtx.Lock() - defer mtx.Unlock() + mBatch.write(false) +} + +func (mBatch *memBatch) WriteSync() { + mBatch.write(true) +} - for _, op := range mBatch.ops { +func (mBatch *memBatch) write(doSync bool) { + if mtx := mBatch.db.Mutex(); mtx != nil { + mtx.Lock() + defer mtx.Unlock() + } + + for i, op := range mBatch.ops { + if doSync && i == (len(mBatch.ops)-1) { + switch op.opType { + case opTypeSet: + mBatch.db.SetNoLockSync(op.key, op.value) + case opTypeDelete: + mBatch.db.DeleteNoLockSync(op.key) + } + break // we're done. + } switch op.opType { case opTypeSet: mBatch.db.SetNoLock(op.key, op.value) case opTypeDelete: mBatch.db.DeleteNoLock(op.key) } + } } diff --git a/db/mem_db.go b/db/mem_db.go index f2c484fa7..2d802947c 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -26,6 +26,11 @@ func NewMemDB() *MemDB { return database } +// Implements atomicSetDeleter. +func (db *MemDB) Mutex() *sync.Mutex { + return &(db.mtx) +} + // Implements DB. func (db *MemDB) Get(key []byte) []byte { db.mtx.Lock() @@ -63,6 +68,11 @@ func (db *MemDB) SetSync(key []byte, value []byte) { // Implements atomicSetDeleter. func (db *MemDB) SetNoLock(key []byte, value []byte) { + db.SetNoLockSync(key, value) +} + +// Implements atomicSetDeleter. +func (db *MemDB) SetNoLockSync(key []byte, value []byte) { key = nonNilBytes(key) value = nonNilBytes(value) @@ -87,6 +97,11 @@ func (db *MemDB) DeleteSync(key []byte) { // Implements atomicSetDeleter. func (db *MemDB) DeleteNoLock(key []byte) { + db.DeleteNoLockSync(key) +} + +// Implements atomicSetDeleter. +func (db *MemDB) DeleteNoLockSync(key []byte) { key = nonNilBytes(key) delete(db.db, string(key)) @@ -122,9 +137,6 @@ func (db *MemDB) Stats() map[string]string { return stats } -//---------------------------------------- -// Batch - // Implements DB. func (db *MemDB) NewBatch() Batch { db.mtx.Lock() @@ -133,10 +145,6 @@ func (db *MemDB) NewBatch() Batch { return &memBatch{db, nil} } -func (db *MemDB) Mutex() *sync.Mutex { - return &(db.mtx) -} - //---------------------------------------- // Iterator @@ -227,7 +235,7 @@ func (itr *memDBIterator) assertIsValid() { func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string { keys := []string{} - for key, _ := range db.db { + for key := range db.db { if IsKeyInDomain([]byte(key), start, end, false) { keys = append(keys, key) } diff --git a/db/prefix_db.go b/db/prefix_db.go new file mode 100644 index 000000000..4381ce070 --- /dev/null +++ b/db/prefix_db.go @@ -0,0 +1,263 @@ +package db + +import ( + "bytes" + "fmt" + "sync" +) + +// IteratePrefix is a convenience function for iterating over a key domain +// restricted by prefix. +func IteratePrefix(db DB, prefix []byte) Iterator { + var start, end []byte + if len(prefix) == 0 { + start = nil + end = nil + } else { + start = cp(prefix) + end = cpIncr(prefix) + } + return db.Iterator(start, end) +} + +/* +TODO: Make test, maybe rename. +// Like IteratePrefix but the iterator strips the prefix from the keys. +func IteratePrefixStripped(db DB, prefix []byte) Iterator { + return newUnprefixIterator(prefix, IteratePrefix(db, prefix)) +} +*/ + +//---------------------------------------- +// prefixDB + +type prefixDB struct { + mtx sync.Mutex + prefix []byte + db DB +} + +// NewPrefixDB lets you namespace multiple DBs within a single DB. +func NewPrefixDB(db DB, prefix []byte) *prefixDB { + return &prefixDB{ + prefix: prefix, + db: db, + } +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) Mutex() *sync.Mutex { + return &(pdb.mtx) +} + +// Implements DB. +func (pdb *prefixDB) Get(key []byte) []byte { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + return pdb.db.Get(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) Has(key []byte) bool { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + return pdb.db.Has(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) Set(key []byte, value []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.Set(pdb.prefixed(key), value) +} + +// Implements DB. +func (pdb *prefixDB) SetSync(key []byte, value []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.SetSync(pdb.prefixed(key), value) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) SetNoLock(key []byte, value []byte) { + pdb.db.Set(pdb.prefixed(key), value) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) SetNoLockSync(key []byte, value []byte) { + pdb.db.SetSync(pdb.prefixed(key), value) +} + +// Implements DB. +func (pdb *prefixDB) Delete(key []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.Delete(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) DeleteSync(key []byte) { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.DeleteSync(pdb.prefixed(key)) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) DeleteNoLock(key []byte) { + pdb.db.Delete(pdb.prefixed(key)) +} + +// Implements atomicSetDeleter. +func (pdb *prefixDB) DeleteNoLockSync(key []byte) { + pdb.db.DeleteSync(pdb.prefixed(key)) +} + +// Implements DB. +func (pdb *prefixDB) Iterator(start, end []byte) Iterator { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pstart := append(pdb.prefix, start...) + pend := []byte(nil) + if end != nil { + pend = append(pdb.prefix, end...) + } + return newUnprefixIterator( + pdb.prefix, + pdb.db.Iterator( + pstart, + pend, + ), + ) +} + +// Implements DB. +func (pdb *prefixDB) ReverseIterator(start, end []byte) Iterator { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pstart := []byte(nil) + if start != nil { + pstart = append(pdb.prefix, start...) + } + pend := []byte(nil) + if end != nil { + pend = append(pdb.prefix, end...) + } + return newUnprefixIterator( + pdb.prefix, + pdb.db.ReverseIterator( + pstart, + pend, + ), + ) +} + +// Implements DB. +func (pdb *prefixDB) NewBatch() Batch { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + return &memBatch{pdb, nil} +} + +// Implements DB. +func (pdb *prefixDB) Close() { + pdb.mtx.Lock() + defer pdb.mtx.Unlock() + + pdb.db.Close() +} + +// Implements DB. +func (pdb *prefixDB) Print() { + fmt.Printf("prefix: %X\n", pdb.prefix) + + itr := pdb.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + key := itr.Key() + value := itr.Value() + fmt.Printf("[%X]:\t[%X]\n", key, value) + } +} + +// Implements DB. +func (pdb *prefixDB) Stats() map[string]string { + stats := make(map[string]string) + stats["prefixdb.prefix.string"] = string(pdb.prefix) + stats["prefixdb.prefix.hex"] = fmt.Sprintf("%X", pdb.prefix) + source := pdb.db.Stats() + for key, value := range source { + stats["prefixdb.source."+key] = value + } + return stats +} + +func (pdb *prefixDB) prefixed(key []byte) []byte { + return append(pdb.prefix, key...) +} + +//---------------------------------------- + +// Strips prefix while iterating from Iterator. +type unprefixIterator struct { + prefix []byte + source Iterator +} + +func newUnprefixIterator(prefix []byte, source Iterator) unprefixIterator { + return unprefixIterator{ + prefix: prefix, + source: source, + } +} + +func (itr unprefixIterator) Domain() (start []byte, end []byte) { + start, end = itr.source.Domain() + if len(start) > 0 { + start = stripPrefix(start, itr.prefix) + } + if len(end) > 0 { + end = stripPrefix(end, itr.prefix) + } + return +} + +func (itr unprefixIterator) Valid() bool { + return itr.source.Valid() +} + +func (itr unprefixIterator) Next() { + itr.source.Next() +} + +func (itr unprefixIterator) Key() (key []byte) { + return stripPrefix(itr.source.Key(), itr.prefix) +} + +func (itr unprefixIterator) Value() (value []byte) { + return itr.source.Value() +} + +func (itr unprefixIterator) Close() { + itr.source.Close() +} + +//---------------------------------------- + +func stripPrefix(key []byte, prefix []byte) (stripped []byte) { + if len(key) < len(prefix) { + panic("should not happen") + } + if !bytes.Equal(key[:len(prefix)], prefix) { + panic("should not happne") + } + return key[len(prefix):] +} diff --git a/db/prefix_db_test.go b/db/prefix_db_test.go new file mode 100644 index 000000000..fd44a7ec8 --- /dev/null +++ b/db/prefix_db_test.go @@ -0,0 +1,44 @@ +package db + +import "testing" + +func TestIteratePrefix(t *testing.T) { + db := NewMemDB() + // Under "key" prefix + db.Set(bz("key"), bz("value")) + db.Set(bz("key1"), bz("value1")) + db.Set(bz("key2"), bz("value2")) + db.Set(bz("key3"), bz("value3")) + db.Set(bz("something"), bz("else")) + db.Set(bz(""), bz("")) + db.Set(bz("k"), bz("val")) + db.Set(bz("ke"), bz("valu")) + db.Set(bz("kee"), bz("valuu")) + xitr := db.Iterator(nil, nil) + xitr.Key() + + pdb := NewPrefixDB(db, bz("key")) + checkValue(t, pdb, bz("key"), nil) + checkValue(t, pdb, bz(""), bz("value")) + checkValue(t, pdb, bz("key1"), nil) + checkValue(t, pdb, bz("1"), bz("value1")) + checkValue(t, pdb, bz("key2"), nil) + checkValue(t, pdb, bz("2"), bz("value2")) + checkValue(t, pdb, bz("key3"), nil) + checkValue(t, pdb, bz("3"), bz("value3")) + checkValue(t, pdb, bz("something"), nil) + checkValue(t, pdb, bz("k"), nil) + checkValue(t, pdb, bz("ke"), nil) + checkValue(t, pdb, bz("kee"), nil) + + itr := pdb.Iterator(nil, nil) + itr.Key() + checkItem(t, itr, bz(""), bz("value")) + checkNext(t, itr, true) + checkItem(t, itr, bz("1"), bz("value1")) + checkNext(t, itr, true) + checkItem(t, itr, bz("2"), bz("value2")) + checkNext(t, itr, true) + checkItem(t, itr, bz("3"), bz("value3")) + itr.Close() +} diff --git a/db/types.go b/db/types.go index 07858087a..ad78859a7 100644 --- a/db/types.go +++ b/db/types.go @@ -1,5 +1,6 @@ package db +// DBs are goroutine safe. type DB interface { // Get returns nil iff key doesn't exist. @@ -35,7 +36,7 @@ type DB interface { // Iterate over a domain of keys in descending order. End is exclusive. // Start must be greater than end, or the Iterator is invalid. // If start is nil, iterates from the last/greatest item (inclusive). - // If end is nil, iterates up to the first/least item (iclusive). + // If end is nil, iterates up to the first/least item (inclusive). // CONTRACT: No writes may happen within a domain while an iterator exists over it. // CONTRACT: start, end readonly []byte ReverseIterator(start, end []byte) Iterator @@ -59,6 +60,7 @@ type DB interface { type Batch interface { SetDeleter Write() + WriteSync() } type SetDeleter interface { @@ -127,7 +129,6 @@ func bz(s string) []byte { func nonNilBytes(bz []byte) []byte { if bz == nil { return []byte{} - } else { - return bz } + return bz } diff --git a/db/util.go b/db/util.go index b0ab7f6ad..1ad5002d6 100644 --- a/db/util.go +++ b/db/util.go @@ -4,35 +4,30 @@ import ( "bytes" ) -func IteratePrefix(db DB, prefix []byte) Iterator { - var start, end []byte - if len(prefix) == 0 { - start = nil - end = nil - } else { - start = cp(prefix) - end = cpIncr(prefix) - } - return db.Iterator(start, end) -} - -//---------------------------------------- - func cp(bz []byte) (ret []byte) { ret = make([]byte, len(bz)) copy(ret, bz) return ret } +// Returns a slice of the same length (big endian) +// except incremented by one. +// Returns nil on overflow (e.g. if bz bytes are all 0xFF) // CONTRACT: len(bz) > 0 func cpIncr(bz []byte) (ret []byte) { + if len(bz) == 0 { + panic("cpIncr expects non-zero bz length") + } ret = cp(bz) for i := len(bz) - 1; i >= 0; i-- { if ret[i] < byte(0xFF) { - ret[i] += 1 + ret[i]++ return - } else { - ret[i] = byte(0x00) + } + ret[i] = byte(0x00) + if i == 0 { + // Overflow + return nil } } return nil @@ -48,13 +43,12 @@ func IsKeyInDomain(key, start, end []byte, isReverse bool) bool { return false } return true - } else { - if start != nil && bytes.Compare(start, key) < 0 { - return false - } - if end != nil && bytes.Compare(key, end) <= 0 { - return false - } - return true } + if start != nil && bytes.Compare(start, key) < 0 { + return false + } + if end != nil && bytes.Compare(key, end) <= 0 { + return false + } + return true } diff --git a/db/util_test.go b/db/util_test.go index 854448af3..44f1f9f73 100644 --- a/db/util_test.go +++ b/db/util_test.go @@ -7,7 +7,7 @@ import ( // Empty iterator for empty db. func TestPrefixIteratorNoMatchNil(t *testing.T) { - for backend, _ := range backends { + for backend := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) itr := IteratePrefix(db, []byte("2")) @@ -19,7 +19,7 @@ func TestPrefixIteratorNoMatchNil(t *testing.T) { // Empty iterator for db populated after iterator created. func TestPrefixIteratorNoMatch1(t *testing.T) { - for backend, _ := range backends { + for backend := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) itr := IteratePrefix(db, []byte("2")) @@ -32,7 +32,7 @@ func TestPrefixIteratorNoMatch1(t *testing.T) { // Empty iterator for prefix starting after db entry. func TestPrefixIteratorNoMatch2(t *testing.T) { - for backend, _ := range backends { + for backend := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) db.SetSync(bz("3"), bz("value_3")) @@ -45,7 +45,7 @@ func TestPrefixIteratorNoMatch2(t *testing.T) { // Iterator with single val for db with single val, starting from that val. func TestPrefixIteratorMatch1(t *testing.T) { - for backend, _ := range backends { + for backend := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) db.SetSync(bz("2"), bz("value_2")) @@ -63,7 +63,7 @@ func TestPrefixIteratorMatch1(t *testing.T) { // Iterator with prefix iterates over everything with same prefix. func TestPrefixIteratorMatches1N(t *testing.T) { - for backend, _ := range backends { + for backend := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) diff --git a/events/README.md b/events/README.md index 7a00d79dc..d7469515e 100644 --- a/events/README.md +++ b/events/README.md @@ -95,7 +95,7 @@ type EventCallback func(data EventData) type EventData interface { } ``` -Generic event data can be typed and registered with tendermint/go-wire +Generic event data can be typed and registered with tendermint/go-amino via concrete implementation of this interface diff --git a/events/events.go b/events/events.go index 12aa07813..f1b2a754e 100644 --- a/events/events.go +++ b/events/events.go @@ -6,10 +6,10 @@ package events import ( "sync" - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) -// Generic event data can be typed and registered with tendermint/go-wire +// Generic event data can be typed and registered with tendermint/go-amino // via concrete implementation of this interface type EventData interface { //AssertIsEventData() @@ -27,7 +27,7 @@ type Fireable interface { } type EventSwitch interface { - Service + cmn.Service Fireable AddListenerForEvent(listenerID, event string, cb EventCallback) @@ -36,7 +36,7 @@ type EventSwitch interface { } type eventSwitch struct { - BaseService + cmn.BaseService mtx sync.RWMutex eventCells map[string]*eventCell @@ -45,7 +45,7 @@ type eventSwitch struct { func NewEventSwitch() EventSwitch { evsw := &eventSwitch{} - evsw.BaseService = *NewBaseService(nil, "EventSwitch", evsw) + evsw.BaseService = *cmn.NewBaseService(nil, "EventSwitch", evsw) return evsw } diff --git a/events/events_test.go b/events/events_test.go index 87db2a304..4995ae730 100644 --- a/events/events_test.go +++ b/events/events_test.go @@ -221,11 +221,11 @@ func TestRemoveListener(t *testing.T) { // add some listeners and make sure they work evsw.AddListenerForEvent("listener", "event1", func(data EventData) { - sum1 += 1 + sum1++ }) evsw.AddListenerForEvent("listener", "event2", func(data EventData) { - sum2 += 1 + sum2++ }) for i := 0; i < count; i++ { evsw.FireEvent("event1", true) diff --git a/flowrate/io_test.go b/flowrate/io_test.go index db40337c9..c84029d5e 100644 --- a/flowrate/io_test.go +++ b/flowrate/io_test.go @@ -121,7 +121,15 @@ func TestWriter(t *testing.T) { w.SetBlocking(true) if n, err := w.Write(b[20:]); n != 80 || err != nil { t.Fatalf("w.Write(b[20:]) expected 80 (); got %v (%v)", n, err) - } else if rt := time.Since(start); rt < _400ms { + } else if rt := time.Since(start); rt < _300ms { + // Explanation for `rt < _300ms` (as opposed to `< _400ms`) + // + // |<-- start | | + // epochs: -----0ms|---100ms|---200ms|---300ms|---400ms + // sends: 20|20 |20 |20 |20# + // + // NOTE: The '#' symbol can thus happen before 400ms is up. + // Thus, we can only panic if rt < _300ms. t.Fatalf("w.Write(b[20:]) returned ahead of time (%v)", rt) } diff --git a/merkle/simple_map_test.go b/merkle/simple_map_test.go index 61210132b..c9c871354 100644 --- a/merkle/simple_map_test.go +++ b/merkle/simple_map_test.go @@ -17,37 +17,37 @@ func TestSimpleMap(t *testing.T) { { db := NewSimpleMap() db.Set("key1", strHasher("value1")) - assert.Equal(t, "19618304d1ad2635c4238bce87f72331b22a11a1", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "acdb4f121bc6f25041eb263ab463f1cd79236a32", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } { db := NewSimpleMap() db.Set("key1", strHasher("value2")) - assert.Equal(t, "51cb96d3d41e1714def72eb4bacc211de9ddf284", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "b8cbf5adee8c524e14f531da9b49adbbbd66fffa", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } { db := NewSimpleMap() db.Set("key1", strHasher("value1")) db.Set("key2", strHasher("value2")) - assert.Equal(t, "58a0a99d5019fdcad4bcf55942e833b2dfab9421", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "1708aabc85bbe00242d3db8c299516aa54e48c38", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } { db := NewSimpleMap() db.Set("key2", strHasher("value2")) // NOTE: out of order db.Set("key1", strHasher("value1")) - assert.Equal(t, "58a0a99d5019fdcad4bcf55942e833b2dfab9421", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "1708aabc85bbe00242d3db8c299516aa54e48c38", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } { db := NewSimpleMap() db.Set("key1", strHasher("value1")) db.Set("key2", strHasher("value2")) db.Set("key3", strHasher("value3")) - assert.Equal(t, "cb56db3c7993e977f4c2789559ae3e5e468a6e9b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "e728afe72ce351eed6aca65c5f78da19b9a6e214", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } { db := NewSimpleMap() db.Set("key2", strHasher("value2")) // NOTE: out of order db.Set("key1", strHasher("value1")) db.Set("key3", strHasher("value3")) - assert.Equal(t, "cb56db3c7993e977f4c2789559ae3e5e468a6e9b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") + assert.Equal(t, "e728afe72ce351eed6aca65c5f78da19b9a6e214", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") } } diff --git a/merkle/simple_proof.go b/merkle/simple_proof.go index 83f89e598..c81ed674a 100644 --- a/merkle/simple_proof.go +++ b/merkle/simple_proof.go @@ -43,9 +43,9 @@ func (sp *SimpleProof) StringIndented(indent string) string { // Use the leafHash and innerHashes to get the root merkle hash. // If the length of the innerHashes slice isn't exactly correct, the result is nil. +// Recursive impl. func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { - // Recursive impl. - if index >= total { + if index >= total || index < 0 || total <= 0 { return nil } switch total { @@ -67,13 +67,12 @@ func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][ return nil } return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) - } else { - rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) - if rightHash == nil { - return nil - } - return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) } + rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) + if rightHash == nil { + return nil + } + return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) } } @@ -81,7 +80,7 @@ func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][ // The node and the tree is thrown away afterwards. // Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. // node.Parent.Hash = hash(node.Hash, node.Right.Hash) or -// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. +// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. type SimpleProofNode struct { Hash []byte Parent *SimpleProofNode diff --git a/merkle/simple_tree.go b/merkle/simple_tree.go index a363ea8e8..9bdf52cb2 100644 --- a/merkle/simple_tree.go +++ b/merkle/simple_tree.go @@ -31,6 +31,9 @@ import ( func SimpleHashFromTwoHashes(left []byte, right []byte) []byte { var hasher = ripemd160.New() err := encodeByteSlice(hasher, left) + if err != nil { + panic(err) + } err = encodeByteSlice(hasher, right) if err != nil { panic(err) diff --git a/merkle/types.go b/merkle/types.go index e0fe35fa8..a0c491a7e 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -28,10 +28,10 @@ type Hasher interface { } //----------------------------------------------------------------------- -// NOTE: these are duplicated from go-wire so we dont need go-wire as a dep +// NOTE: these are duplicated from go-amino so we dont need go-amino as a dep func encodeByteSlice(w io.Writer, bz []byte) (err error) { - err = encodeVarint(w, int64(len(bz))) + err = encodeUvarint(w, uint64(len(bz))) if err != nil { return } @@ -39,9 +39,9 @@ func encodeByteSlice(w io.Writer, bz []byte) (err error) { return } -func encodeVarint(w io.Writer, i int64) (err error) { +func encodeUvarint(w io.Writer, i uint64) (err error) { var buf [10]byte - n := binary.PutVarint(buf[:], i) + n := binary.PutUvarint(buf[:], i) _, err = w.Write(buf[0:n]) return } diff --git a/pubsub/pubsub.go b/pubsub/pubsub.go index 54a4b8aed..90f6e4ae6 100644 --- a/pubsub/pubsub.go +++ b/pubsub/pubsub.go @@ -28,6 +28,16 @@ const ( shutdown ) +var ( + // ErrSubscriptionNotFound is returned when a client tries to unsubscribe + // from not existing subscription. + ErrSubscriptionNotFound = errors.New("subscription not found") + + // ErrAlreadySubscribed is returned when a client tries to subscribe twice or + // more using the same query. + ErrAlreadySubscribed = errors.New("already subscribed") +) + type cmd struct { op operation query Query @@ -52,7 +62,7 @@ type Server struct { cmdsCap int mtx sync.RWMutex - subscriptions map[string]map[string]struct{} // subscriber -> query -> struct{} + subscriptions map[string]map[string]Query // subscriber -> query (string) -> Query } // Option sets a parameter for the server. @@ -63,7 +73,7 @@ type Option func(*Server) // provided, the resulting server's queue is unbuffered. func NewServer(options ...Option) *Server { s := &Server{ - subscriptions: make(map[string]map[string]struct{}), + subscriptions: make(map[string]map[string]Query), } s.BaseService = *cmn.NewBaseService(nil, "PubSub", s) @@ -106,16 +116,16 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou } s.mtx.RUnlock() if ok { - return errors.New("already subscribed") + return ErrAlreadySubscribed } select { case s.cmds <- cmd{op: sub, clientID: clientID, query: query, ch: out}: s.mtx.Lock() if _, ok = s.subscriptions[clientID]; !ok { - s.subscriptions[clientID] = make(map[string]struct{}) + s.subscriptions[clientID] = make(map[string]Query) } - s.subscriptions[clientID][query.String()] = struct{}{} + s.subscriptions[clientID][query.String()] = query s.mtx.Unlock() return nil case <-ctx.Done(): @@ -127,18 +137,20 @@ func (s *Server) Subscribe(ctx context.Context, clientID string, query Query, ou // returned to the caller if the context is canceled or if subscription does // not exist. func (s *Server) Unsubscribe(ctx context.Context, clientID string, query Query) error { + var origQuery Query s.mtx.RLock() clientSubscriptions, ok := s.subscriptions[clientID] if ok { - _, ok = clientSubscriptions[query.String()] + origQuery, ok = clientSubscriptions[query.String()] } s.mtx.RUnlock() if !ok { - return errors.New("subscription not found") + return ErrSubscriptionNotFound } + // original query is used here because we're using pointers as map keys select { - case s.cmds <- cmd{op: unsub, clientID: clientID, query: query}: + case s.cmds <- cmd{op: unsub, clientID: clientID, query: origQuery}: s.mtx.Lock() delete(clientSubscriptions, query.String()) s.mtx.Unlock() @@ -155,7 +167,7 @@ func (s *Server) UnsubscribeAll(ctx context.Context, clientID string) error { _, ok := s.subscriptions[clientID] s.mtx.RUnlock() if !ok { - return errors.New("subscription not found") + return ErrSubscriptionNotFound } select { @@ -209,6 +221,11 @@ func (s *Server) OnStart() error { return nil } +// OnReset implements Service.OnReset +func (s *Server) OnReset() error { + return nil +} + func (s *Server) loop(state state) { loop: for cmd := range s.cmds { diff --git a/pubsub/pubsub_test.go b/pubsub/pubsub_test.go index 84b6aa218..2af7cea46 100644 --- a/pubsub/pubsub_test.go +++ b/pubsub/pubsub_test.go @@ -101,9 +101,9 @@ func TestUnsubscribe(t *testing.T) { ctx := context.Background() ch := make(chan interface{}) - err := s.Subscribe(ctx, clientID, query.Empty{}, ch) + err := s.Subscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'"), ch) require.NoError(t, err) - err = s.Unsubscribe(ctx, clientID, query.Empty{}) + err = s.Unsubscribe(ctx, clientID, query.MustParse("tm.events.type='NewBlock'")) require.NoError(t, err) err = s.Publish(ctx, "Nick Fury") diff --git a/test/mutate.go b/test/mutate.go index 1dbe7a6bf..76534e8b1 100644 --- a/test/mutate.go +++ b/test/mutate.go @@ -1,7 +1,7 @@ package test import ( - . "github.com/tendermint/tmlibs/common" + cmn "github.com/tendermint/tmlibs/common" ) // Contract: !bytes.Equal(input, output) && len(input) >= len(output) @@ -17,11 +17,11 @@ func MutateByteSlice(bytez []byte) []byte { bytez = mBytez // Try a random mutation - switch RandInt() % 2 { + switch cmn.RandInt() % 2 { case 0: // Mutate a single byte - bytez[RandInt()%len(bytez)] += byte(RandInt()%255 + 1) + bytez[cmn.RandInt()%len(bytez)] += byte(cmn.RandInt()%255 + 1) case 1: // Remove an arbitrary byte - pos := RandInt() % len(bytez) + pos := cmn.RandInt() % len(bytez) bytez = append(bytez[:pos], bytez[pos+1:]...) } return bytez diff --git a/version/version.go b/version/version.go index 5449f1478..b389a63a0 100644 --- a/version/version.go +++ b/version/version.go @@ -1,3 +1,3 @@ package version -const Version = "0.7.1" +const Version = "0.8.1"