@ -1,15 +1,148 @@ | |||||
package common | package common | ||||
import "sync" | |||||
func Parallel(tasks ...func()) { | |||||
var wg sync.WaitGroup | |||||
wg.Add(len(tasks)) | |||||
for _, task := range tasks { | |||||
go func(task func()) { | |||||
task() | |||||
wg.Done() | |||||
}(task) | |||||
} | |||||
wg.Wait() | |||||
import ( | |||||
"sync/atomic" | |||||
) | |||||
//---------------------------------------- | |||||
// Task | |||||
// val: the value returned after task execution. | |||||
// err: the error returned during task completion. | |||||
// abort: tells Parallel to return, whether or not all tasks have completed. | |||||
type Task func(i int) (val interface{}, err error, abort bool) | |||||
type TaskResult struct { | |||||
Value interface{} | |||||
Error error | |||||
} | |||||
type TaskResultCh <-chan TaskResult | |||||
type taskResultOK struct { | |||||
TaskResult | |||||
OK bool | |||||
} | |||||
type TaskResultSet struct { | |||||
chz []TaskResultCh | |||||
results []taskResultOK | |||||
} | |||||
func newTaskResultSet(chz []TaskResultCh) *TaskResultSet { | |||||
return &TaskResultSet{ | |||||
chz: chz, | |||||
results: nil, | |||||
} | |||||
} | |||||
func (trs *TaskResultSet) Channels() []TaskResultCh { | |||||
return trs.chz | |||||
} | |||||
func (trs *TaskResultSet) LatestResult(index int) (TaskResult, bool) { | |||||
if len(trs.results) <= index { | |||||
return TaskResult{}, false | |||||
} | |||||
resultOK := trs.results[index] | |||||
return resultOK.TaskResult, resultOK.OK | |||||
} | |||||
// NOTE: Not concurrency safe. | |||||
func (trs *TaskResultSet) Reap() *TaskResultSet { | |||||
if trs.results == nil { | |||||
trs.results = make([]taskResultOK, len(trs.chz)) | |||||
} | |||||
for i := 0; i < len(trs.results); i++ { | |||||
var trch = trs.chz[i] | |||||
select { | |||||
case result := <-trch: | |||||
// Overwrite result. | |||||
trs.results[i] = taskResultOK{ | |||||
TaskResult: result, | |||||
OK: true, | |||||
} | |||||
default: | |||||
// Do nothing. | |||||
} | |||||
} | |||||
return trs | |||||
} | |||||
// Returns the firstmost (by task index) error as | |||||
// discovered by all previous Reap() calls. | |||||
func (trs *TaskResultSet) FirstValue() interface{} { | |||||
for _, result := range trs.results { | |||||
if result.Value != nil { | |||||
return result.Value | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
// Returns the firstmost (by task index) error as | |||||
// discovered by all previous Reap() calls. | |||||
func (trs *TaskResultSet) FirstError() error { | |||||
for _, result := range trs.results { | |||||
if result.Error != nil { | |||||
return result.Error | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
//---------------------------------------- | |||||
// Parallel | |||||
// Run tasks in parallel, with ability to abort early. | |||||
// Returns ok=false iff any of the tasks returned abort=true. | |||||
// NOTE: Do not implement quit features here. Instead, provide convenient | |||||
// concurrent quit-like primitives, passed implicitly via Task closures. (e.g. | |||||
// it's not Parallel's concern how you quit/abort your tasks). | |||||
func Parallel(tasks ...Task) (trs *TaskResultSet, ok bool) { | |||||
var taskResultChz = make([]TaskResultCh, len(tasks)) // To return. | |||||
var taskDoneCh = make(chan bool, len(tasks)) // A "wait group" channel, early abort if any true received. | |||||
var numPanics = new(int32) // Keep track of panics to set ok=false later. | |||||
ok = true // We will set it to false iff any tasks panic'd or returned abort. | |||||
// Start all tasks in parallel in separate goroutines. | |||||
// When the task is complete, it will appear in the | |||||
// respective taskResultCh (associated by task index). | |||||
for i, task := range tasks { | |||||
var taskResultCh = make(chan TaskResult, 1) // Capacity for 1 result. | |||||
taskResultChz[i] = taskResultCh | |||||
go func(i int, task Task, taskResultCh chan TaskResult) { | |||||
// Recovery | |||||
defer func() { | |||||
if pnk := recover(); pnk != nil { | |||||
atomic.AddInt32(numPanics, 1) | |||||
taskResultCh <- TaskResult{nil, ErrorWrap(pnk, "Panic in task")} | |||||
taskDoneCh <- false | |||||
} | |||||
}() | |||||
// Run the task. | |||||
var val, err, abort = task(i) | |||||
// Send val/err to taskResultCh. | |||||
// NOTE: Below this line, nothing must panic/ | |||||
taskResultCh <- TaskResult{val, err} | |||||
// Decrement waitgroup. | |||||
taskDoneCh <- abort | |||||
}(i, task, taskResultCh) | |||||
} | |||||
// Wait until all tasks are done, or until abort. | |||||
// DONE_LOOP: | |||||
for i := 0; i < len(tasks); i++ { | |||||
abort := <-taskDoneCh | |||||
if abort { | |||||
ok = false | |||||
break | |||||
} | |||||
} | |||||
// Ok is also false if there were any panics. | |||||
// We must do this check here (after DONE_LOOP). | |||||
ok = ok && (atomic.LoadInt32(numPanics) == 0) | |||||
return newTaskResultSet(taskResultChz).Reap(), ok | |||||
} | } |
@ -0,0 +1,152 @@ | |||||
package common | |||||
import ( | |||||
"errors" | |||||
"fmt" | |||||
"sync/atomic" | |||||
"testing" | |||||
"time" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
func TestParallel(t *testing.T) { | |||||
// Create tasks. | |||||
var counter = new(int32) | |||||
var tasks = make([]Task, 100*1000) | |||||
for i := 0; i < len(tasks); i++ { | |||||
tasks[i] = func(i int) (res interface{}, err error, abort bool) { | |||||
atomic.AddInt32(counter, 1) | |||||
return -1 * i, nil, false | |||||
} | |||||
} | |||||
// Run in parallel. | |||||
var trs, ok = Parallel(tasks...) | |||||
assert.True(t, ok) | |||||
// Verify. | |||||
assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") | |||||
var failedTasks int | |||||
for i := 0; i < len(tasks); i++ { | |||||
taskResult, ok := trs.LatestResult(i) | |||||
if !ok { | |||||
assert.Fail(t, "Task #%v did not complete.", i) | |||||
failedTasks++ | |||||
} else if taskResult.Error != nil { | |||||
assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) | |||||
failedTasks++ | |||||
} else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { | |||||
assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) | |||||
failedTasks++ | |||||
} else { | |||||
// Good! | |||||
} | |||||
} | |||||
assert.Equal(t, failedTasks, 0, "No task should have failed") | |||||
assert.Nil(t, trs.FirstError(), "There should be no errors") | |||||
assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") | |||||
} | |||||
func TestParallelAbort(t *testing.T) { | |||||
var flow1 = make(chan struct{}, 1) | |||||
var flow2 = make(chan struct{}, 1) | |||||
var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking. | |||||
var flow4 = make(chan struct{}, 1) | |||||
// Create tasks. | |||||
var tasks = []Task{ | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
assert.Equal(t, i, 0) | |||||
flow1 <- struct{}{} | |||||
return 0, nil, false | |||||
}, | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
assert.Equal(t, i, 1) | |||||
flow2 <- <-flow1 | |||||
return 1, errors.New("some error"), false | |||||
}, | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
assert.Equal(t, i, 2) | |||||
flow3 <- <-flow2 | |||||
return 2, nil, true | |||||
}, | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
assert.Equal(t, i, 3) | |||||
<-flow4 | |||||
return 3, nil, false | |||||
}, | |||||
} | |||||
// Run in parallel. | |||||
var taskResultSet, ok = Parallel(tasks...) | |||||
assert.False(t, ok, "ok should be false since we aborted task #2.") | |||||
// Verify task #3. | |||||
// Initially taskResultSet.chz[3] sends nothing since flow4 didn't send. | |||||
waitTimeout(t, taskResultSet.chz[3], "Task #3") | |||||
// Now let the last task (#3) complete after abort. | |||||
flow4 <- <-flow3 | |||||
// Verify task #0, #1, #2. | |||||
checkResult(t, taskResultSet, 0, 0, nil, nil) | |||||
checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) | |||||
checkResult(t, taskResultSet, 2, 2, nil, nil) | |||||
} | |||||
func TestParallelRecover(t *testing.T) { | |||||
// Create tasks. | |||||
var tasks = []Task{ | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
return 0, nil, false | |||||
}, | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
return 1, errors.New("some error"), false | |||||
}, | |||||
func(i int) (res interface{}, err error, abort bool) { | |||||
panic(2) | |||||
}, | |||||
} | |||||
// Run in parallel. | |||||
var taskResultSet, ok = Parallel(tasks...) | |||||
assert.False(t, ok, "ok should be false since we panic'd in task #2.") | |||||
// Verify task #0, #1, #2. | |||||
checkResult(t, taskResultSet, 0, 0, nil, nil) | |||||
checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) | |||||
checkResult(t, taskResultSet, 2, nil, nil, 2) | |||||
} | |||||
// Wait for result | |||||
func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { | |||||
taskResult, ok := taskResultSet.LatestResult(index) | |||||
taskName := fmt.Sprintf("Task #%v", index) | |||||
assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) | |||||
assert.Equal(t, val, taskResult.Value, taskName) | |||||
if err != nil { | |||||
assert.Equal(t, err, taskResult.Error, taskName) | |||||
} else if pnk != nil { | |||||
assert.Equal(t, pnk, taskResult.Error.(Error).Cause(), taskName) | |||||
} else { | |||||
assert.Nil(t, taskResult.Error, taskName) | |||||
} | |||||
} | |||||
// Wait for timeout (no result) | |||||
func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) { | |||||
select { | |||||
case _, ok := <-taskResultCh: | |||||
if !ok { | |||||
assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName) | |||||
} else { | |||||
assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName) | |||||
} | |||||
case <-time.After(1 * time.Second): // TODO use deterministic time? | |||||
// Good! | |||||
} | |||||
} |
@ -0,0 +1,107 @@ | |||||
package common | |||||
import ( | |||||
fmt "fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
func TestErrorPanic(t *testing.T) { | |||||
type pnk struct { | |||||
msg string | |||||
} | |||||
capturePanic := func() (err Error) { | |||||
defer func() { | |||||
if r := recover(); r != nil { | |||||
err = ErrorWrap(r, "This is the message in ErrorWrap(r, message).") | |||||
} | |||||
return | |||||
}() | |||||
panic(pnk{"something"}) | |||||
return nil | |||||
} | |||||
var err = capturePanic() | |||||
assert.Equal(t, pnk{"something"}, err.Cause()) | |||||
assert.Equal(t, pnk{"something"}, err.T()) | |||||
assert.Equal(t, "This is the message in ErrorWrap(r, message).", err.Message()) | |||||
assert.Equal(t, "Error{`This is the message in ErrorWrap(r, message).` (cause: {something})}", fmt.Sprintf("%v", err)) | |||||
assert.Contains(t, fmt.Sprintf("%#v", err), "Message: This is the message in ErrorWrap(r, message).") | |||||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||||
} | |||||
func TestErrorWrapSomething(t *testing.T) { | |||||
var err = ErrorWrap("something", "formatter%v%v", 0, 1) | |||||
assert.Equal(t, "something", err.Cause()) | |||||
assert.Equal(t, "something", err.T()) | |||||
assert.Equal(t, "formatter01", err.Message()) | |||||
assert.Equal(t, "Error{`formatter01` (cause: something)}", fmt.Sprintf("%v", err)) | |||||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||||
} | |||||
func TestErrorWrapNothing(t *testing.T) { | |||||
var err = ErrorWrap(nil, "formatter%v%v", 0, 1) | |||||
assert.Equal(t, nil, err.Cause()) | |||||
assert.Equal(t, nil, err.T()) | |||||
assert.Equal(t, "formatter01", err.Message()) | |||||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||||
} | |||||
func TestErrorNewError(t *testing.T) { | |||||
var err = NewError("formatter%v%v", 0, 1) | |||||
assert.Equal(t, nil, err.Cause()) | |||||
assert.Equal(t, "formatter%v%v", err.T()) | |||||
assert.Equal(t, "formatter01", err.Message()) | |||||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||||
assert.NotContains(t, fmt.Sprintf("%#v", err), "Stack Trace") | |||||
} | |||||
func TestErrorNewErrorWithStacktrace(t *testing.T) { | |||||
var err = NewError("formatter%v%v", 0, 1).Stacktrace() | |||||
assert.Equal(t, nil, err.Cause()) | |||||
assert.Equal(t, "formatter%v%v", err.T()) | |||||
assert.Equal(t, "formatter01", err.Message()) | |||||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||||
} | |||||
func TestErrorNewErrorWithTrace(t *testing.T) { | |||||
var err = NewError("formatter%v%v", 0, 1) | |||||
err.Trace("trace %v", 1) | |||||
err.Trace("trace %v", 2) | |||||
err.Trace("trace %v", 3) | |||||
assert.Equal(t, nil, err.Cause()) | |||||
assert.Equal(t, "formatter%v%v", err.T()) | |||||
assert.Equal(t, "formatter01", err.Message()) | |||||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||||
dump := fmt.Sprintf("%#v", err) | |||||
assert.NotContains(t, dump, "Stack Trace") | |||||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 1`, dump) | |||||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 2`, dump) | |||||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 3`, dump) | |||||
} | |||||
func TestErrorWrapError(t *testing.T) { | |||||
var err1 error = NewError("my message") | |||||
var err2 error = ErrorWrap(err1, "another message") | |||||
assert.Equal(t, err1, err2) | |||||
} |
@ -0,0 +1,29 @@ | |||||
package common | |||||
import "reflect" | |||||
// Go lacks a simple and safe way to see if something is a typed nil. | |||||
// See: | |||||
// - https://dave.cheney.net/2017/08/09/typed-nils-in-go-2 | |||||
// - https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I/discussion | |||||
// - https://github.com/golang/go/issues/21538 | |||||
func IsTypedNil(o interface{}) bool { | |||||
rv := reflect.ValueOf(o) | |||||
switch rv.Kind() { | |||||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice: | |||||
return rv.IsNil() | |||||
default: | |||||
return false | |||||
} | |||||
} | |||||
// Returns true if it has zero length. | |||||
func IsEmpty(o interface{}) bool { | |||||
rv := reflect.ValueOf(o) | |||||
switch rv.Kind() { | |||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: | |||||
return rv.Len() == 0 | |||||
default: | |||||
return false | |||||
} | |||||
} |
@ -0,0 +1,194 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
func TestDBIteratorSingleKey(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
itr := db.Iterator(nil, nil) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, false) | |||||
checkValid(t, itr, false) | |||||
checkNextPanics(t, itr) | |||||
// Once invalid... | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorTwoKeys(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
db.SetSync(bz("2"), bz("value_1")) | |||||
{ // Fail by calling Next too much | |||||
itr := db.Iterator(nil, nil) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, true) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, false) | |||||
checkValid(t, itr, false) | |||||
checkNextPanics(t, itr) | |||||
// Once invalid... | |||||
checkInvalid(t, itr) | |||||
} | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorMany(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
keys := make([][]byte, 100) | |||||
for i := 0; i < 100; i++ { | |||||
keys[i] = []byte{byte(i)} | |||||
} | |||||
value := []byte{5} | |||||
for _, k := range keys { | |||||
db.Set(k, value) | |||||
} | |||||
itr := db.Iterator(nil, nil) | |||||
defer itr.Close() | |||||
for ; itr.Valid(); itr.Next() { | |||||
assert.Equal(t, db.Get(itr.Key()), itr.Value()) | |||||
} | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorEmpty(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := db.Iterator(nil, nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorEmptyBeginAfter(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := db.Iterator(bz("1"), nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorNonemptyBeginAfter(t *testing.T) { | |||||
for backend := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
itr := db.Iterator(bz("2"), nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBBatchWrite1(t *testing.T) { | |||||
mdb := newMockDB() | |||||
ddb := NewDebugDB(t.Name(), mdb) | |||||
batch := ddb.NewBatch() | |||||
batch.Set(bz("1"), bz("1")) | |||||
batch.Set(bz("2"), bz("2")) | |||||
batch.Delete(bz("3")) | |||||
batch.Set(bz("4"), bz("4")) | |||||
batch.Write() | |||||
assert.Equal(t, 0, mdb.calls["Set"]) | |||||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||||
} | |||||
func TestDBBatchWrite2(t *testing.T) { | |||||
mdb := newMockDB() | |||||
ddb := NewDebugDB(t.Name(), mdb) | |||||
batch := ddb.NewBatch() | |||||
batch.Set(bz("1"), bz("1")) | |||||
batch.Set(bz("2"), bz("2")) | |||||
batch.Set(bz("4"), bz("4")) | |||||
batch.Delete(bz("3")) | |||||
batch.Write() | |||||
assert.Equal(t, 0, mdb.calls["Set"]) | |||||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||||
} | |||||
func TestDBBatchWriteSync1(t *testing.T) { | |||||
mdb := newMockDB() | |||||
ddb := NewDebugDB(t.Name(), mdb) | |||||
batch := ddb.NewBatch() | |||||
batch.Set(bz("1"), bz("1")) | |||||
batch.Set(bz("2"), bz("2")) | |||||
batch.Delete(bz("3")) | |||||
batch.Set(bz("4"), bz("4")) | |||||
batch.WriteSync() | |||||
assert.Equal(t, 0, mdb.calls["Set"]) | |||||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||||
assert.Equal(t, 2, mdb.calls["SetNoLock"]) | |||||
assert.Equal(t, 1, mdb.calls["SetNoLockSync"]) | |||||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||||
} | |||||
func TestDBBatchWriteSync2(t *testing.T) { | |||||
mdb := newMockDB() | |||||
ddb := NewDebugDB(t.Name(), mdb) | |||||
batch := ddb.NewBatch() | |||||
batch.Set(bz("1"), bz("1")) | |||||
batch.Set(bz("2"), bz("2")) | |||||
batch.Set(bz("4"), bz("4")) | |||||
batch.Delete(bz("3")) | |||||
batch.WriteSync() | |||||
assert.Equal(t, 0, mdb.calls["Set"]) | |||||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||||
assert.Equal(t, 0, mdb.calls["DeleteNoLock"]) | |||||
assert.Equal(t, 1, mdb.calls["DeleteNoLockSync"]) | |||||
} |
@ -0,0 +1,216 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"sync" | |||||
) | |||||
//---------------------------------------- | |||||
// debugDB | |||||
type debugDB struct { | |||||
label string | |||||
db DB | |||||
} | |||||
// For printing all operationgs to the console for debugging. | |||||
func NewDebugDB(label string, db DB) debugDB { | |||||
return debugDB{ | |||||
label: label, | |||||
db: db, | |||||
} | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (ddb debugDB) Mutex() *sync.Mutex { return nil } | |||||
// Implements DB. | |||||
func (ddb debugDB) Get(key []byte) (value []byte) { | |||||
defer fmt.Printf("%v.Get(%X) %X\n", ddb.label, key, value) | |||||
value = ddb.db.Get(key) | |||||
return | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Has(key []byte) (has bool) { | |||||
defer fmt.Printf("%v.Has(%X) %v\n", ddb.label, key, has) | |||||
return ddb.db.Has(key) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Set(key []byte, value []byte) { | |||||
fmt.Printf("%v.Set(%X, %X)\n", ddb.label, key, value) | |||||
ddb.db.Set(key, value) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) SetSync(key []byte, value []byte) { | |||||
fmt.Printf("%v.SetSync(%X, %X)\n", ddb.label, key, value) | |||||
ddb.db.SetSync(key, value) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (ddb debugDB) SetNoLock(key []byte, value []byte) { | |||||
fmt.Printf("%v.SetNoLock(%X, %X)\n", ddb.label, key, value) | |||||
ddb.db.Set(key, value) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (ddb debugDB) SetNoLockSync(key []byte, value []byte) { | |||||
fmt.Printf("%v.SetNoLockSync(%X, %X)\n", ddb.label, key, value) | |||||
ddb.db.SetSync(key, value) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Delete(key []byte) { | |||||
fmt.Printf("%v.Delete(%X)\n", ddb.label, key) | |||||
ddb.db.Delete(key) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) DeleteSync(key []byte) { | |||||
fmt.Printf("%v.DeleteSync(%X)\n", ddb.label, key) | |||||
ddb.db.DeleteSync(key) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (ddb debugDB) DeleteNoLock(key []byte) { | |||||
fmt.Printf("%v.DeleteNoLock(%X)\n", ddb.label, key) | |||||
ddb.db.Delete(key) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (ddb debugDB) DeleteNoLockSync(key []byte) { | |||||
fmt.Printf("%v.DeleteNoLockSync(%X)\n", ddb.label, key) | |||||
ddb.db.DeleteSync(key) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Iterator(start, end []byte) Iterator { | |||||
fmt.Printf("%v.Iterator(%X, %X)\n", ddb.label, start, end) | |||||
return NewDebugIterator(ddb.label, ddb.db.Iterator(start, end)) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) ReverseIterator(start, end []byte) Iterator { | |||||
fmt.Printf("%v.ReverseIterator(%X, %X)\n", ddb.label, start, end) | |||||
return NewDebugIterator(ddb.label, ddb.db.ReverseIterator(start, end)) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) NewBatch() Batch { | |||||
fmt.Printf("%v.NewBatch()\n", ddb.label) | |||||
return NewDebugBatch(ddb.label, ddb.db.NewBatch()) | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Close() { | |||||
fmt.Printf("%v.Close()\n", ddb.label) | |||||
ddb.db.Close() | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Print() { | |||||
ddb.db.Print() | |||||
} | |||||
// Implements DB. | |||||
func (ddb debugDB) Stats() map[string]string { | |||||
return ddb.db.Stats() | |||||
} | |||||
//---------------------------------------- | |||||
// debugIterator | |||||
type debugIterator struct { | |||||
label string | |||||
itr Iterator | |||||
} | |||||
// For printing all operationgs to the console for debugging. | |||||
func NewDebugIterator(label string, itr Iterator) debugIterator { | |||||
return debugIterator{ | |||||
label: label, | |||||
itr: itr, | |||||
} | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Domain() (start []byte, end []byte) { | |||||
defer fmt.Printf("%v.itr.Domain() (%X,%X)\n", ditr.label, start, end) | |||||
start, end = ditr.itr.Domain() | |||||
return | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Valid() (ok bool) { | |||||
defer fmt.Printf("%v.itr.Valid() %v\n", ditr.label, ok) | |||||
ok = ditr.itr.Valid() | |||||
return | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Next() { | |||||
fmt.Printf("%v.itr.Next()\n", ditr.label) | |||||
ditr.itr.Next() | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Key() (key []byte) { | |||||
fmt.Printf("%v.itr.Key() %X\n", ditr.label, key) | |||||
key = ditr.itr.Key() | |||||
return | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Value() (value []byte) { | |||||
fmt.Printf("%v.itr.Value() %X\n", ditr.label, value) | |||||
value = ditr.itr.Value() | |||||
return | |||||
} | |||||
// Implements Iterator. | |||||
func (ditr debugIterator) Close() { | |||||
fmt.Printf("%v.itr.Close()\n", ditr.label) | |||||
ditr.itr.Close() | |||||
} | |||||
//---------------------------------------- | |||||
// debugBatch | |||||
type debugBatch struct { | |||||
label string | |||||
bch Batch | |||||
} | |||||
// For printing all operationgs to the console for debugging. | |||||
func NewDebugBatch(label string, bch Batch) debugBatch { | |||||
return debugBatch{ | |||||
label: label, | |||||
bch: bch, | |||||
} | |||||
} | |||||
// Implements Batch. | |||||
func (dbch debugBatch) Set(key, value []byte) { | |||||
fmt.Printf("%v.batch.Set(%X, %X)\n", dbch.label, key, value) | |||||
dbch.bch.Set(key, value) | |||||
} | |||||
// Implements Batch. | |||||
func (dbch debugBatch) Delete(key []byte) { | |||||
fmt.Printf("%v.batch.Delete(%X)\n", dbch.label, key) | |||||
dbch.bch.Delete(key) | |||||
} | |||||
// Implements Batch. | |||||
func (dbch debugBatch) Write() { | |||||
fmt.Printf("%v.batch.Write()\n", dbch.label) | |||||
dbch.bch.Write() | |||||
} | |||||
// Implements Batch. | |||||
func (dbch debugBatch) WriteSync() { | |||||
fmt.Printf("%v.batch.WriteSync()\n", dbch.label) | |||||
dbch.bch.WriteSync() | |||||
} |
@ -0,0 +1,263 @@ | |||||
package db | |||||
import ( | |||||
"bytes" | |||||
"fmt" | |||||
"sync" | |||||
) | |||||
// IteratePrefix is a convenience function for iterating over a key domain | |||||
// restricted by prefix. | |||||
func IteratePrefix(db DB, prefix []byte) Iterator { | |||||
var start, end []byte | |||||
if len(prefix) == 0 { | |||||
start = nil | |||||
end = nil | |||||
} else { | |||||
start = cp(prefix) | |||||
end = cpIncr(prefix) | |||||
} | |||||
return db.Iterator(start, end) | |||||
} | |||||
/* | |||||
TODO: Make test, maybe rename. | |||||
// Like IteratePrefix but the iterator strips the prefix from the keys. | |||||
func IteratePrefixStripped(db DB, prefix []byte) Iterator { | |||||
return newUnprefixIterator(prefix, IteratePrefix(db, prefix)) | |||||
} | |||||
*/ | |||||
//---------------------------------------- | |||||
// prefixDB | |||||
type prefixDB struct { | |||||
mtx sync.Mutex | |||||
prefix []byte | |||||
db DB | |||||
} | |||||
// NewPrefixDB lets you namespace multiple DBs within a single DB. | |||||
func NewPrefixDB(db DB, prefix []byte) *prefixDB { | |||||
return &prefixDB{ | |||||
prefix: prefix, | |||||
db: db, | |||||
} | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (pdb *prefixDB) Mutex() *sync.Mutex { | |||||
return &(pdb.mtx) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Get(key []byte) []byte { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
return pdb.db.Get(pdb.prefixed(key)) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Has(key []byte) bool { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
return pdb.db.Has(pdb.prefixed(key)) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Set(key []byte, value []byte) { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pdb.db.Set(pdb.prefixed(key), value) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) SetSync(key []byte, value []byte) { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pdb.db.SetSync(pdb.prefixed(key), value) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (pdb *prefixDB) SetNoLock(key []byte, value []byte) { | |||||
pdb.db.Set(pdb.prefixed(key), value) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (pdb *prefixDB) SetNoLockSync(key []byte, value []byte) { | |||||
pdb.db.SetSync(pdb.prefixed(key), value) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Delete(key []byte) { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pdb.db.Delete(pdb.prefixed(key)) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) DeleteSync(key []byte) { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pdb.db.DeleteSync(pdb.prefixed(key)) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (pdb *prefixDB) DeleteNoLock(key []byte) { | |||||
pdb.db.Delete(pdb.prefixed(key)) | |||||
} | |||||
// Implements atomicSetDeleter. | |||||
func (pdb *prefixDB) DeleteNoLockSync(key []byte) { | |||||
pdb.db.DeleteSync(pdb.prefixed(key)) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Iterator(start, end []byte) Iterator { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pstart := append(pdb.prefix, start...) | |||||
pend := []byte(nil) | |||||
if end != nil { | |||||
pend = append(pdb.prefix, end...) | |||||
} | |||||
return newUnprefixIterator( | |||||
pdb.prefix, | |||||
pdb.db.Iterator( | |||||
pstart, | |||||
pend, | |||||
), | |||||
) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) ReverseIterator(start, end []byte) Iterator { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pstart := []byte(nil) | |||||
if start != nil { | |||||
pstart = append(pdb.prefix, start...) | |||||
} | |||||
pend := []byte(nil) | |||||
if end != nil { | |||||
pend = append(pdb.prefix, end...) | |||||
} | |||||
return newUnprefixIterator( | |||||
pdb.prefix, | |||||
pdb.db.ReverseIterator( | |||||
pstart, | |||||
pend, | |||||
), | |||||
) | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) NewBatch() Batch { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
return &memBatch{pdb, nil} | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Close() { | |||||
pdb.mtx.Lock() | |||||
defer pdb.mtx.Unlock() | |||||
pdb.db.Close() | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Print() { | |||||
fmt.Printf("prefix: %X\n", pdb.prefix) | |||||
itr := pdb.Iterator(nil, nil) | |||||
defer itr.Close() | |||||
for ; itr.Valid(); itr.Next() { | |||||
key := itr.Key() | |||||
value := itr.Value() | |||||
fmt.Printf("[%X]:\t[%X]\n", key, value) | |||||
} | |||||
} | |||||
// Implements DB. | |||||
func (pdb *prefixDB) Stats() map[string]string { | |||||
stats := make(map[string]string) | |||||
stats["prefixdb.prefix.string"] = string(pdb.prefix) | |||||
stats["prefixdb.prefix.hex"] = fmt.Sprintf("%X", pdb.prefix) | |||||
source := pdb.db.Stats() | |||||
for key, value := range source { | |||||
stats["prefixdb.source."+key] = value | |||||
} | |||||
return stats | |||||
} | |||||
func (pdb *prefixDB) prefixed(key []byte) []byte { | |||||
return append(pdb.prefix, key...) | |||||
} | |||||
//---------------------------------------- | |||||
// Strips prefix while iterating from Iterator. | |||||
type unprefixIterator struct { | |||||
prefix []byte | |||||
source Iterator | |||||
} | |||||
func newUnprefixIterator(prefix []byte, source Iterator) unprefixIterator { | |||||
return unprefixIterator{ | |||||
prefix: prefix, | |||||
source: source, | |||||
} | |||||
} | |||||
func (itr unprefixIterator) Domain() (start []byte, end []byte) { | |||||
start, end = itr.source.Domain() | |||||
if len(start) > 0 { | |||||
start = stripPrefix(start, itr.prefix) | |||||
} | |||||
if len(end) > 0 { | |||||
end = stripPrefix(end, itr.prefix) | |||||
} | |||||
return | |||||
} | |||||
func (itr unprefixIterator) Valid() bool { | |||||
return itr.source.Valid() | |||||
} | |||||
func (itr unprefixIterator) Next() { | |||||
itr.source.Next() | |||||
} | |||||
func (itr unprefixIterator) Key() (key []byte) { | |||||
return stripPrefix(itr.source.Key(), itr.prefix) | |||||
} | |||||
func (itr unprefixIterator) Value() (value []byte) { | |||||
return itr.source.Value() | |||||
} | |||||
func (itr unprefixIterator) Close() { | |||||
itr.source.Close() | |||||
} | |||||
//---------------------------------------- | |||||
func stripPrefix(key []byte, prefix []byte) (stripped []byte) { | |||||
if len(key) < len(prefix) { | |||||
panic("should not happen") | |||||
} | |||||
if !bytes.Equal(key[:len(prefix)], prefix) { | |||||
panic("should not happne") | |||||
} | |||||
return key[len(prefix):] | |||||
} |
@ -0,0 +1,44 @@ | |||||
package db | |||||
import "testing" | |||||
func TestIteratePrefix(t *testing.T) { | |||||
db := NewMemDB() | |||||
// Under "key" prefix | |||||
db.Set(bz("key"), bz("value")) | |||||
db.Set(bz("key1"), bz("value1")) | |||||
db.Set(bz("key2"), bz("value2")) | |||||
db.Set(bz("key3"), bz("value3")) | |||||
db.Set(bz("something"), bz("else")) | |||||
db.Set(bz(""), bz("")) | |||||
db.Set(bz("k"), bz("val")) | |||||
db.Set(bz("ke"), bz("valu")) | |||||
db.Set(bz("kee"), bz("valuu")) | |||||
xitr := db.Iterator(nil, nil) | |||||
xitr.Key() | |||||
pdb := NewPrefixDB(db, bz("key")) | |||||
checkValue(t, pdb, bz("key"), nil) | |||||
checkValue(t, pdb, bz(""), bz("value")) | |||||
checkValue(t, pdb, bz("key1"), nil) | |||||
checkValue(t, pdb, bz("1"), bz("value1")) | |||||
checkValue(t, pdb, bz("key2"), nil) | |||||
checkValue(t, pdb, bz("2"), bz("value2")) | |||||
checkValue(t, pdb, bz("key3"), nil) | |||||
checkValue(t, pdb, bz("3"), bz("value3")) | |||||
checkValue(t, pdb, bz("something"), nil) | |||||
checkValue(t, pdb, bz("k"), nil) | |||||
checkValue(t, pdb, bz("ke"), nil) | |||||
checkValue(t, pdb, bz("kee"), nil) | |||||
itr := pdb.Iterator(nil, nil) | |||||
itr.Key() | |||||
checkItem(t, itr, bz(""), bz("value")) | |||||
checkNext(t, itr, true) | |||||
checkItem(t, itr, bz("1"), bz("value1")) | |||||
checkNext(t, itr, true) | |||||
checkItem(t, itr, bz("2"), bz("value2")) | |||||
checkNext(t, itr, true) | |||||
checkItem(t, itr, bz("3"), bz("value3")) | |||||
itr.Close() | |||||
} |
@ -1,3 +1,3 @@ | |||||
package version | package version | ||||
const Version = "0.7.1" | |||||
const Version = "0.8.1" |