@ -1,15 +1,148 @@ | |||
package common | |||
import "sync" | |||
func Parallel(tasks ...func()) { | |||
var wg sync.WaitGroup | |||
wg.Add(len(tasks)) | |||
for _, task := range tasks { | |||
go func(task func()) { | |||
task() | |||
wg.Done() | |||
}(task) | |||
} | |||
wg.Wait() | |||
import ( | |||
"sync/atomic" | |||
) | |||
//---------------------------------------- | |||
// Task | |||
// val: the value returned after task execution. | |||
// err: the error returned during task completion. | |||
// abort: tells Parallel to return, whether or not all tasks have completed. | |||
type Task func(i int) (val interface{}, err error, abort bool) | |||
type TaskResult struct { | |||
Value interface{} | |||
Error error | |||
} | |||
type TaskResultCh <-chan TaskResult | |||
type taskResultOK struct { | |||
TaskResult | |||
OK bool | |||
} | |||
type TaskResultSet struct { | |||
chz []TaskResultCh | |||
results []taskResultOK | |||
} | |||
func newTaskResultSet(chz []TaskResultCh) *TaskResultSet { | |||
return &TaskResultSet{ | |||
chz: chz, | |||
results: nil, | |||
} | |||
} | |||
func (trs *TaskResultSet) Channels() []TaskResultCh { | |||
return trs.chz | |||
} | |||
func (trs *TaskResultSet) LatestResult(index int) (TaskResult, bool) { | |||
if len(trs.results) <= index { | |||
return TaskResult{}, false | |||
} | |||
resultOK := trs.results[index] | |||
return resultOK.TaskResult, resultOK.OK | |||
} | |||
// NOTE: Not concurrency safe. | |||
func (trs *TaskResultSet) Reap() *TaskResultSet { | |||
if trs.results == nil { | |||
trs.results = make([]taskResultOK, len(trs.chz)) | |||
} | |||
for i := 0; i < len(trs.results); i++ { | |||
var trch = trs.chz[i] | |||
select { | |||
case result := <-trch: | |||
// Overwrite result. | |||
trs.results[i] = taskResultOK{ | |||
TaskResult: result, | |||
OK: true, | |||
} | |||
default: | |||
// Do nothing. | |||
} | |||
} | |||
return trs | |||
} | |||
// Returns the firstmost (by task index) error as | |||
// discovered by all previous Reap() calls. | |||
func (trs *TaskResultSet) FirstValue() interface{} { | |||
for _, result := range trs.results { | |||
if result.Value != nil { | |||
return result.Value | |||
} | |||
} | |||
return nil | |||
} | |||
// Returns the firstmost (by task index) error as | |||
// discovered by all previous Reap() calls. | |||
func (trs *TaskResultSet) FirstError() error { | |||
for _, result := range trs.results { | |||
if result.Error != nil { | |||
return result.Error | |||
} | |||
} | |||
return nil | |||
} | |||
//---------------------------------------- | |||
// Parallel | |||
// Run tasks in parallel, with ability to abort early. | |||
// Returns ok=false iff any of the tasks returned abort=true. | |||
// NOTE: Do not implement quit features here. Instead, provide convenient | |||
// concurrent quit-like primitives, passed implicitly via Task closures. (e.g. | |||
// it's not Parallel's concern how you quit/abort your tasks). | |||
func Parallel(tasks ...Task) (trs *TaskResultSet, ok bool) { | |||
var taskResultChz = make([]TaskResultCh, len(tasks)) // To return. | |||
var taskDoneCh = make(chan bool, len(tasks)) // A "wait group" channel, early abort if any true received. | |||
var numPanics = new(int32) // Keep track of panics to set ok=false later. | |||
ok = true // We will set it to false iff any tasks panic'd or returned abort. | |||
// Start all tasks in parallel in separate goroutines. | |||
// When the task is complete, it will appear in the | |||
// respective taskResultCh (associated by task index). | |||
for i, task := range tasks { | |||
var taskResultCh = make(chan TaskResult, 1) // Capacity for 1 result. | |||
taskResultChz[i] = taskResultCh | |||
go func(i int, task Task, taskResultCh chan TaskResult) { | |||
// Recovery | |||
defer func() { | |||
if pnk := recover(); pnk != nil { | |||
atomic.AddInt32(numPanics, 1) | |||
taskResultCh <- TaskResult{nil, ErrorWrap(pnk, "Panic in task")} | |||
taskDoneCh <- false | |||
} | |||
}() | |||
// Run the task. | |||
var val, err, abort = task(i) | |||
// Send val/err to taskResultCh. | |||
// NOTE: Below this line, nothing must panic/ | |||
taskResultCh <- TaskResult{val, err} | |||
// Decrement waitgroup. | |||
taskDoneCh <- abort | |||
}(i, task, taskResultCh) | |||
} | |||
// Wait until all tasks are done, or until abort. | |||
// DONE_LOOP: | |||
for i := 0; i < len(tasks); i++ { | |||
abort := <-taskDoneCh | |||
if abort { | |||
ok = false | |||
break | |||
} | |||
} | |||
// Ok is also false if there were any panics. | |||
// We must do this check here (after DONE_LOOP). | |||
ok = ok && (atomic.LoadInt32(numPanics) == 0) | |||
return newTaskResultSet(taskResultChz).Reap(), ok | |||
} |
@ -0,0 +1,152 @@ | |||
package common | |||
import ( | |||
"errors" | |||
"fmt" | |||
"sync/atomic" | |||
"testing" | |||
"time" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestParallel(t *testing.T) { | |||
// Create tasks. | |||
var counter = new(int32) | |||
var tasks = make([]Task, 100*1000) | |||
for i := 0; i < len(tasks); i++ { | |||
tasks[i] = func(i int) (res interface{}, err error, abort bool) { | |||
atomic.AddInt32(counter, 1) | |||
return -1 * i, nil, false | |||
} | |||
} | |||
// Run in parallel. | |||
var trs, ok = Parallel(tasks...) | |||
assert.True(t, ok) | |||
// Verify. | |||
assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") | |||
var failedTasks int | |||
for i := 0; i < len(tasks); i++ { | |||
taskResult, ok := trs.LatestResult(i) | |||
if !ok { | |||
assert.Fail(t, "Task #%v did not complete.", i) | |||
failedTasks++ | |||
} else if taskResult.Error != nil { | |||
assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) | |||
failedTasks++ | |||
} else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { | |||
assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) | |||
failedTasks++ | |||
} else { | |||
// Good! | |||
} | |||
} | |||
assert.Equal(t, failedTasks, 0, "No task should have failed") | |||
assert.Nil(t, trs.FirstError(), "There should be no errors") | |||
assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") | |||
} | |||
func TestParallelAbort(t *testing.T) { | |||
var flow1 = make(chan struct{}, 1) | |||
var flow2 = make(chan struct{}, 1) | |||
var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking. | |||
var flow4 = make(chan struct{}, 1) | |||
// Create tasks. | |||
var tasks = []Task{ | |||
func(i int) (res interface{}, err error, abort bool) { | |||
assert.Equal(t, i, 0) | |||
flow1 <- struct{}{} | |||
return 0, nil, false | |||
}, | |||
func(i int) (res interface{}, err error, abort bool) { | |||
assert.Equal(t, i, 1) | |||
flow2 <- <-flow1 | |||
return 1, errors.New("some error"), false | |||
}, | |||
func(i int) (res interface{}, err error, abort bool) { | |||
assert.Equal(t, i, 2) | |||
flow3 <- <-flow2 | |||
return 2, nil, true | |||
}, | |||
func(i int) (res interface{}, err error, abort bool) { | |||
assert.Equal(t, i, 3) | |||
<-flow4 | |||
return 3, nil, false | |||
}, | |||
} | |||
// Run in parallel. | |||
var taskResultSet, ok = Parallel(tasks...) | |||
assert.False(t, ok, "ok should be false since we aborted task #2.") | |||
// Verify task #3. | |||
// Initially taskResultSet.chz[3] sends nothing since flow4 didn't send. | |||
waitTimeout(t, taskResultSet.chz[3], "Task #3") | |||
// Now let the last task (#3) complete after abort. | |||
flow4 <- <-flow3 | |||
// Verify task #0, #1, #2. | |||
checkResult(t, taskResultSet, 0, 0, nil, nil) | |||
checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) | |||
checkResult(t, taskResultSet, 2, 2, nil, nil) | |||
} | |||
func TestParallelRecover(t *testing.T) { | |||
// Create tasks. | |||
var tasks = []Task{ | |||
func(i int) (res interface{}, err error, abort bool) { | |||
return 0, nil, false | |||
}, | |||
func(i int) (res interface{}, err error, abort bool) { | |||
return 1, errors.New("some error"), false | |||
}, | |||
func(i int) (res interface{}, err error, abort bool) { | |||
panic(2) | |||
}, | |||
} | |||
// Run in parallel. | |||
var taskResultSet, ok = Parallel(tasks...) | |||
assert.False(t, ok, "ok should be false since we panic'd in task #2.") | |||
// Verify task #0, #1, #2. | |||
checkResult(t, taskResultSet, 0, 0, nil, nil) | |||
checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) | |||
checkResult(t, taskResultSet, 2, nil, nil, 2) | |||
} | |||
// Wait for result | |||
func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { | |||
taskResult, ok := taskResultSet.LatestResult(index) | |||
taskName := fmt.Sprintf("Task #%v", index) | |||
assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) | |||
assert.Equal(t, val, taskResult.Value, taskName) | |||
if err != nil { | |||
assert.Equal(t, err, taskResult.Error, taskName) | |||
} else if pnk != nil { | |||
assert.Equal(t, pnk, taskResult.Error.(Error).Cause(), taskName) | |||
} else { | |||
assert.Nil(t, taskResult.Error, taskName) | |||
} | |||
} | |||
// Wait for timeout (no result) | |||
func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) { | |||
select { | |||
case _, ok := <-taskResultCh: | |||
if !ok { | |||
assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName) | |||
} else { | |||
assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName) | |||
} | |||
case <-time.After(1 * time.Second): // TODO use deterministic time? | |||
// Good! | |||
} | |||
} |
@ -0,0 +1,107 @@ | |||
package common | |||
import ( | |||
fmt "fmt" | |||
"testing" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestErrorPanic(t *testing.T) { | |||
type pnk struct { | |||
msg string | |||
} | |||
capturePanic := func() (err Error) { | |||
defer func() { | |||
if r := recover(); r != nil { | |||
err = ErrorWrap(r, "This is the message in ErrorWrap(r, message).") | |||
} | |||
return | |||
}() | |||
panic(pnk{"something"}) | |||
return nil | |||
} | |||
var err = capturePanic() | |||
assert.Equal(t, pnk{"something"}, err.Cause()) | |||
assert.Equal(t, pnk{"something"}, err.T()) | |||
assert.Equal(t, "This is the message in ErrorWrap(r, message).", err.Message()) | |||
assert.Equal(t, "Error{`This is the message in ErrorWrap(r, message).` (cause: {something})}", fmt.Sprintf("%v", err)) | |||
assert.Contains(t, fmt.Sprintf("%#v", err), "Message: This is the message in ErrorWrap(r, message).") | |||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||
} | |||
func TestErrorWrapSomething(t *testing.T) { | |||
var err = ErrorWrap("something", "formatter%v%v", 0, 1) | |||
assert.Equal(t, "something", err.Cause()) | |||
assert.Equal(t, "something", err.T()) | |||
assert.Equal(t, "formatter01", err.Message()) | |||
assert.Equal(t, "Error{`formatter01` (cause: something)}", fmt.Sprintf("%v", err)) | |||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||
} | |||
func TestErrorWrapNothing(t *testing.T) { | |||
var err = ErrorWrap(nil, "formatter%v%v", 0, 1) | |||
assert.Equal(t, nil, err.Cause()) | |||
assert.Equal(t, nil, err.T()) | |||
assert.Equal(t, "formatter01", err.Message()) | |||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||
} | |||
func TestErrorNewError(t *testing.T) { | |||
var err = NewError("formatter%v%v", 0, 1) | |||
assert.Equal(t, nil, err.Cause()) | |||
assert.Equal(t, "formatter%v%v", err.T()) | |||
assert.Equal(t, "formatter01", err.Message()) | |||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||
assert.NotContains(t, fmt.Sprintf("%#v", err), "Stack Trace") | |||
} | |||
func TestErrorNewErrorWithStacktrace(t *testing.T) { | |||
var err = NewError("formatter%v%v", 0, 1).Stacktrace() | |||
assert.Equal(t, nil, err.Cause()) | |||
assert.Equal(t, "formatter%v%v", err.T()) | |||
assert.Equal(t, "formatter01", err.Message()) | |||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||
assert.Contains(t, fmt.Sprintf("%#v", err), "Stack Trace:\n 0") | |||
} | |||
func TestErrorNewErrorWithTrace(t *testing.T) { | |||
var err = NewError("formatter%v%v", 0, 1) | |||
err.Trace("trace %v", 1) | |||
err.Trace("trace %v", 2) | |||
err.Trace("trace %v", 3) | |||
assert.Equal(t, nil, err.Cause()) | |||
assert.Equal(t, "formatter%v%v", err.T()) | |||
assert.Equal(t, "formatter01", err.Message()) | |||
assert.Equal(t, "Error{`formatter01`}", fmt.Sprintf("%v", err)) | |||
assert.Regexp(t, `Message: formatter01\n`, fmt.Sprintf("%#v", err)) | |||
dump := fmt.Sprintf("%#v", err) | |||
assert.NotContains(t, dump, "Stack Trace") | |||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 1`, dump) | |||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 2`, dump) | |||
assert.Regexp(t, `common/errors_test\.go:[0-9]+ - trace 3`, dump) | |||
} | |||
func TestErrorWrapError(t *testing.T) { | |||
var err1 error = NewError("my message") | |||
var err2 error = ErrorWrap(err1, "another message") | |||
assert.Equal(t, err1, err2) | |||
} |
@ -0,0 +1,29 @@ | |||
package common | |||
import "reflect" | |||
// Go lacks a simple and safe way to see if something is a typed nil. | |||
// See: | |||
// - https://dave.cheney.net/2017/08/09/typed-nils-in-go-2 | |||
// - https://groups.google.com/forum/#!topic/golang-nuts/wnH302gBa4I/discussion | |||
// - https://github.com/golang/go/issues/21538 | |||
func IsTypedNil(o interface{}) bool { | |||
rv := reflect.ValueOf(o) | |||
switch rv.Kind() { | |||
case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, reflect.Slice: | |||
return rv.IsNil() | |||
default: | |||
return false | |||
} | |||
} | |||
// Returns true if it has zero length. | |||
func IsEmpty(o interface{}) bool { | |||
rv := reflect.ValueOf(o) | |||
switch rv.Kind() { | |||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String: | |||
return rv.Len() == 0 | |||
default: | |||
return false | |||
} | |||
} |
@ -0,0 +1,194 @@ | |||
package db | |||
import ( | |||
"fmt" | |||
"testing" | |||
"github.com/stretchr/testify/assert" | |||
) | |||
func TestDBIteratorSingleKey(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
db.SetSync(bz("1"), bz("value_1")) | |||
itr := db.Iterator(nil, nil) | |||
checkValid(t, itr, true) | |||
checkNext(t, itr, false) | |||
checkValid(t, itr, false) | |||
checkNextPanics(t, itr) | |||
// Once invalid... | |||
checkInvalid(t, itr) | |||
}) | |||
} | |||
} | |||
func TestDBIteratorTwoKeys(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
db.SetSync(bz("1"), bz("value_1")) | |||
db.SetSync(bz("2"), bz("value_1")) | |||
{ // Fail by calling Next too much | |||
itr := db.Iterator(nil, nil) | |||
checkValid(t, itr, true) | |||
checkNext(t, itr, true) | |||
checkValid(t, itr, true) | |||
checkNext(t, itr, false) | |||
checkValid(t, itr, false) | |||
checkNextPanics(t, itr) | |||
// Once invalid... | |||
checkInvalid(t, itr) | |||
} | |||
}) | |||
} | |||
} | |||
func TestDBIteratorMany(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
keys := make([][]byte, 100) | |||
for i := 0; i < 100; i++ { | |||
keys[i] = []byte{byte(i)} | |||
} | |||
value := []byte{5} | |||
for _, k := range keys { | |||
db.Set(k, value) | |||
} | |||
itr := db.Iterator(nil, nil) | |||
defer itr.Close() | |||
for ; itr.Valid(); itr.Next() { | |||
assert.Equal(t, db.Get(itr.Key()), itr.Value()) | |||
} | |||
}) | |||
} | |||
} | |||
func TestDBIteratorEmpty(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
itr := db.Iterator(nil, nil) | |||
checkInvalid(t, itr) | |||
}) | |||
} | |||
} | |||
func TestDBIteratorEmptyBeginAfter(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
itr := db.Iterator(bz("1"), nil) | |||
checkInvalid(t, itr) | |||
}) | |||
} | |||
} | |||
func TestDBIteratorNonemptyBeginAfter(t *testing.T) { | |||
for backend := range backends { | |||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||
db := newTempDB(t, backend) | |||
db.SetSync(bz("1"), bz("value_1")) | |||
itr := db.Iterator(bz("2"), nil) | |||
checkInvalid(t, itr) | |||
}) | |||
} | |||
} | |||
func TestDBBatchWrite1(t *testing.T) { | |||
mdb := newMockDB() | |||
ddb := NewDebugDB(t.Name(), mdb) | |||
batch := ddb.NewBatch() | |||
batch.Set(bz("1"), bz("1")) | |||
batch.Set(bz("2"), bz("2")) | |||
batch.Delete(bz("3")) | |||
batch.Set(bz("4"), bz("4")) | |||
batch.Write() | |||
assert.Equal(t, 0, mdb.calls["Set"]) | |||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||
} | |||
func TestDBBatchWrite2(t *testing.T) { | |||
mdb := newMockDB() | |||
ddb := NewDebugDB(t.Name(), mdb) | |||
batch := ddb.NewBatch() | |||
batch.Set(bz("1"), bz("1")) | |||
batch.Set(bz("2"), bz("2")) | |||
batch.Set(bz("4"), bz("4")) | |||
batch.Delete(bz("3")) | |||
batch.Write() | |||
assert.Equal(t, 0, mdb.calls["Set"]) | |||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||
} | |||
func TestDBBatchWriteSync1(t *testing.T) { | |||
mdb := newMockDB() | |||
ddb := NewDebugDB(t.Name(), mdb) | |||
batch := ddb.NewBatch() | |||
batch.Set(bz("1"), bz("1")) | |||
batch.Set(bz("2"), bz("2")) | |||
batch.Delete(bz("3")) | |||
batch.Set(bz("4"), bz("4")) | |||
batch.WriteSync() | |||
assert.Equal(t, 0, mdb.calls["Set"]) | |||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||
assert.Equal(t, 2, mdb.calls["SetNoLock"]) | |||
assert.Equal(t, 1, mdb.calls["SetNoLockSync"]) | |||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||
assert.Equal(t, 1, mdb.calls["DeleteNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteNoLockSync"]) | |||
} | |||
func TestDBBatchWriteSync2(t *testing.T) { | |||
mdb := newMockDB() | |||
ddb := NewDebugDB(t.Name(), mdb) | |||
batch := ddb.NewBatch() | |||
batch.Set(bz("1"), bz("1")) | |||
batch.Set(bz("2"), bz("2")) | |||
batch.Set(bz("4"), bz("4")) | |||
batch.Delete(bz("3")) | |||
batch.WriteSync() | |||
assert.Equal(t, 0, mdb.calls["Set"]) | |||
assert.Equal(t, 0, mdb.calls["SetSync"]) | |||
assert.Equal(t, 3, mdb.calls["SetNoLock"]) | |||
assert.Equal(t, 0, mdb.calls["SetNoLockSync"]) | |||
assert.Equal(t, 0, mdb.calls["Delete"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteSync"]) | |||
assert.Equal(t, 0, mdb.calls["DeleteNoLock"]) | |||
assert.Equal(t, 1, mdb.calls["DeleteNoLockSync"]) | |||
} |
@ -0,0 +1,216 @@ | |||
package db | |||
import ( | |||
"fmt" | |||
"sync" | |||
) | |||
//---------------------------------------- | |||
// debugDB | |||
type debugDB struct { | |||
label string | |||
db DB | |||
} | |||
// For printing all operationgs to the console for debugging. | |||
func NewDebugDB(label string, db DB) debugDB { | |||
return debugDB{ | |||
label: label, | |||
db: db, | |||
} | |||
} | |||
// Implements atomicSetDeleter. | |||
func (ddb debugDB) Mutex() *sync.Mutex { return nil } | |||
// Implements DB. | |||
func (ddb debugDB) Get(key []byte) (value []byte) { | |||
defer fmt.Printf("%v.Get(%X) %X\n", ddb.label, key, value) | |||
value = ddb.db.Get(key) | |||
return | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Has(key []byte) (has bool) { | |||
defer fmt.Printf("%v.Has(%X) %v\n", ddb.label, key, has) | |||
return ddb.db.Has(key) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Set(key []byte, value []byte) { | |||
fmt.Printf("%v.Set(%X, %X)\n", ddb.label, key, value) | |||
ddb.db.Set(key, value) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) SetSync(key []byte, value []byte) { | |||
fmt.Printf("%v.SetSync(%X, %X)\n", ddb.label, key, value) | |||
ddb.db.SetSync(key, value) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (ddb debugDB) SetNoLock(key []byte, value []byte) { | |||
fmt.Printf("%v.SetNoLock(%X, %X)\n", ddb.label, key, value) | |||
ddb.db.Set(key, value) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (ddb debugDB) SetNoLockSync(key []byte, value []byte) { | |||
fmt.Printf("%v.SetNoLockSync(%X, %X)\n", ddb.label, key, value) | |||
ddb.db.SetSync(key, value) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Delete(key []byte) { | |||
fmt.Printf("%v.Delete(%X)\n", ddb.label, key) | |||
ddb.db.Delete(key) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) DeleteSync(key []byte) { | |||
fmt.Printf("%v.DeleteSync(%X)\n", ddb.label, key) | |||
ddb.db.DeleteSync(key) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (ddb debugDB) DeleteNoLock(key []byte) { | |||
fmt.Printf("%v.DeleteNoLock(%X)\n", ddb.label, key) | |||
ddb.db.Delete(key) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (ddb debugDB) DeleteNoLockSync(key []byte) { | |||
fmt.Printf("%v.DeleteNoLockSync(%X)\n", ddb.label, key) | |||
ddb.db.DeleteSync(key) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Iterator(start, end []byte) Iterator { | |||
fmt.Printf("%v.Iterator(%X, %X)\n", ddb.label, start, end) | |||
return NewDebugIterator(ddb.label, ddb.db.Iterator(start, end)) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) ReverseIterator(start, end []byte) Iterator { | |||
fmt.Printf("%v.ReverseIterator(%X, %X)\n", ddb.label, start, end) | |||
return NewDebugIterator(ddb.label, ddb.db.ReverseIterator(start, end)) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) NewBatch() Batch { | |||
fmt.Printf("%v.NewBatch()\n", ddb.label) | |||
return NewDebugBatch(ddb.label, ddb.db.NewBatch()) | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Close() { | |||
fmt.Printf("%v.Close()\n", ddb.label) | |||
ddb.db.Close() | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Print() { | |||
ddb.db.Print() | |||
} | |||
// Implements DB. | |||
func (ddb debugDB) Stats() map[string]string { | |||
return ddb.db.Stats() | |||
} | |||
//---------------------------------------- | |||
// debugIterator | |||
type debugIterator struct { | |||
label string | |||
itr Iterator | |||
} | |||
// For printing all operationgs to the console for debugging. | |||
func NewDebugIterator(label string, itr Iterator) debugIterator { | |||
return debugIterator{ | |||
label: label, | |||
itr: itr, | |||
} | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Domain() (start []byte, end []byte) { | |||
defer fmt.Printf("%v.itr.Domain() (%X,%X)\n", ditr.label, start, end) | |||
start, end = ditr.itr.Domain() | |||
return | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Valid() (ok bool) { | |||
defer fmt.Printf("%v.itr.Valid() %v\n", ditr.label, ok) | |||
ok = ditr.itr.Valid() | |||
return | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Next() { | |||
fmt.Printf("%v.itr.Next()\n", ditr.label) | |||
ditr.itr.Next() | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Key() (key []byte) { | |||
fmt.Printf("%v.itr.Key() %X\n", ditr.label, key) | |||
key = ditr.itr.Key() | |||
return | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Value() (value []byte) { | |||
fmt.Printf("%v.itr.Value() %X\n", ditr.label, value) | |||
value = ditr.itr.Value() | |||
return | |||
} | |||
// Implements Iterator. | |||
func (ditr debugIterator) Close() { | |||
fmt.Printf("%v.itr.Close()\n", ditr.label) | |||
ditr.itr.Close() | |||
} | |||
//---------------------------------------- | |||
// debugBatch | |||
type debugBatch struct { | |||
label string | |||
bch Batch | |||
} | |||
// For printing all operationgs to the console for debugging. | |||
func NewDebugBatch(label string, bch Batch) debugBatch { | |||
return debugBatch{ | |||
label: label, | |||
bch: bch, | |||
} | |||
} | |||
// Implements Batch. | |||
func (dbch debugBatch) Set(key, value []byte) { | |||
fmt.Printf("%v.batch.Set(%X, %X)\n", dbch.label, key, value) | |||
dbch.bch.Set(key, value) | |||
} | |||
// Implements Batch. | |||
func (dbch debugBatch) Delete(key []byte) { | |||
fmt.Printf("%v.batch.Delete(%X)\n", dbch.label, key) | |||
dbch.bch.Delete(key) | |||
} | |||
// Implements Batch. | |||
func (dbch debugBatch) Write() { | |||
fmt.Printf("%v.batch.Write()\n", dbch.label) | |||
dbch.bch.Write() | |||
} | |||
// Implements Batch. | |||
func (dbch debugBatch) WriteSync() { | |||
fmt.Printf("%v.batch.WriteSync()\n", dbch.label) | |||
dbch.bch.WriteSync() | |||
} |
@ -0,0 +1,263 @@ | |||
package db | |||
import ( | |||
"bytes" | |||
"fmt" | |||
"sync" | |||
) | |||
// IteratePrefix is a convenience function for iterating over a key domain | |||
// restricted by prefix. | |||
func IteratePrefix(db DB, prefix []byte) Iterator { | |||
var start, end []byte | |||
if len(prefix) == 0 { | |||
start = nil | |||
end = nil | |||
} else { | |||
start = cp(prefix) | |||
end = cpIncr(prefix) | |||
} | |||
return db.Iterator(start, end) | |||
} | |||
/* | |||
TODO: Make test, maybe rename. | |||
// Like IteratePrefix but the iterator strips the prefix from the keys. | |||
func IteratePrefixStripped(db DB, prefix []byte) Iterator { | |||
return newUnprefixIterator(prefix, IteratePrefix(db, prefix)) | |||
} | |||
*/ | |||
//---------------------------------------- | |||
// prefixDB | |||
type prefixDB struct { | |||
mtx sync.Mutex | |||
prefix []byte | |||
db DB | |||
} | |||
// NewPrefixDB lets you namespace multiple DBs within a single DB. | |||
func NewPrefixDB(db DB, prefix []byte) *prefixDB { | |||
return &prefixDB{ | |||
prefix: prefix, | |||
db: db, | |||
} | |||
} | |||
// Implements atomicSetDeleter. | |||
func (pdb *prefixDB) Mutex() *sync.Mutex { | |||
return &(pdb.mtx) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Get(key []byte) []byte { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
return pdb.db.Get(pdb.prefixed(key)) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Has(key []byte) bool { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
return pdb.db.Has(pdb.prefixed(key)) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Set(key []byte, value []byte) { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pdb.db.Set(pdb.prefixed(key), value) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) SetSync(key []byte, value []byte) { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pdb.db.SetSync(pdb.prefixed(key), value) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (pdb *prefixDB) SetNoLock(key []byte, value []byte) { | |||
pdb.db.Set(pdb.prefixed(key), value) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (pdb *prefixDB) SetNoLockSync(key []byte, value []byte) { | |||
pdb.db.SetSync(pdb.prefixed(key), value) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Delete(key []byte) { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pdb.db.Delete(pdb.prefixed(key)) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) DeleteSync(key []byte) { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pdb.db.DeleteSync(pdb.prefixed(key)) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (pdb *prefixDB) DeleteNoLock(key []byte) { | |||
pdb.db.Delete(pdb.prefixed(key)) | |||
} | |||
// Implements atomicSetDeleter. | |||
func (pdb *prefixDB) DeleteNoLockSync(key []byte) { | |||
pdb.db.DeleteSync(pdb.prefixed(key)) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Iterator(start, end []byte) Iterator { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pstart := append(pdb.prefix, start...) | |||
pend := []byte(nil) | |||
if end != nil { | |||
pend = append(pdb.prefix, end...) | |||
} | |||
return newUnprefixIterator( | |||
pdb.prefix, | |||
pdb.db.Iterator( | |||
pstart, | |||
pend, | |||
), | |||
) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) ReverseIterator(start, end []byte) Iterator { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pstart := []byte(nil) | |||
if start != nil { | |||
pstart = append(pdb.prefix, start...) | |||
} | |||
pend := []byte(nil) | |||
if end != nil { | |||
pend = append(pdb.prefix, end...) | |||
} | |||
return newUnprefixIterator( | |||
pdb.prefix, | |||
pdb.db.ReverseIterator( | |||
pstart, | |||
pend, | |||
), | |||
) | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) NewBatch() Batch { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
return &memBatch{pdb, nil} | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Close() { | |||
pdb.mtx.Lock() | |||
defer pdb.mtx.Unlock() | |||
pdb.db.Close() | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Print() { | |||
fmt.Printf("prefix: %X\n", pdb.prefix) | |||
itr := pdb.Iterator(nil, nil) | |||
defer itr.Close() | |||
for ; itr.Valid(); itr.Next() { | |||
key := itr.Key() | |||
value := itr.Value() | |||
fmt.Printf("[%X]:\t[%X]\n", key, value) | |||
} | |||
} | |||
// Implements DB. | |||
func (pdb *prefixDB) Stats() map[string]string { | |||
stats := make(map[string]string) | |||
stats["prefixdb.prefix.string"] = string(pdb.prefix) | |||
stats["prefixdb.prefix.hex"] = fmt.Sprintf("%X", pdb.prefix) | |||
source := pdb.db.Stats() | |||
for key, value := range source { | |||
stats["prefixdb.source."+key] = value | |||
} | |||
return stats | |||
} | |||
func (pdb *prefixDB) prefixed(key []byte) []byte { | |||
return append(pdb.prefix, key...) | |||
} | |||
//---------------------------------------- | |||
// Strips prefix while iterating from Iterator. | |||
type unprefixIterator struct { | |||
prefix []byte | |||
source Iterator | |||
} | |||
func newUnprefixIterator(prefix []byte, source Iterator) unprefixIterator { | |||
return unprefixIterator{ | |||
prefix: prefix, | |||
source: source, | |||
} | |||
} | |||
func (itr unprefixIterator) Domain() (start []byte, end []byte) { | |||
start, end = itr.source.Domain() | |||
if len(start) > 0 { | |||
start = stripPrefix(start, itr.prefix) | |||
} | |||
if len(end) > 0 { | |||
end = stripPrefix(end, itr.prefix) | |||
} | |||
return | |||
} | |||
func (itr unprefixIterator) Valid() bool { | |||
return itr.source.Valid() | |||
} | |||
func (itr unprefixIterator) Next() { | |||
itr.source.Next() | |||
} | |||
func (itr unprefixIterator) Key() (key []byte) { | |||
return stripPrefix(itr.source.Key(), itr.prefix) | |||
} | |||
func (itr unprefixIterator) Value() (value []byte) { | |||
return itr.source.Value() | |||
} | |||
func (itr unprefixIterator) Close() { | |||
itr.source.Close() | |||
} | |||
//---------------------------------------- | |||
func stripPrefix(key []byte, prefix []byte) (stripped []byte) { | |||
if len(key) < len(prefix) { | |||
panic("should not happen") | |||
} | |||
if !bytes.Equal(key[:len(prefix)], prefix) { | |||
panic("should not happne") | |||
} | |||
return key[len(prefix):] | |||
} |
@ -0,0 +1,44 @@ | |||
package db | |||
import "testing" | |||
func TestIteratePrefix(t *testing.T) { | |||
db := NewMemDB() | |||
// Under "key" prefix | |||
db.Set(bz("key"), bz("value")) | |||
db.Set(bz("key1"), bz("value1")) | |||
db.Set(bz("key2"), bz("value2")) | |||
db.Set(bz("key3"), bz("value3")) | |||
db.Set(bz("something"), bz("else")) | |||
db.Set(bz(""), bz("")) | |||
db.Set(bz("k"), bz("val")) | |||
db.Set(bz("ke"), bz("valu")) | |||
db.Set(bz("kee"), bz("valuu")) | |||
xitr := db.Iterator(nil, nil) | |||
xitr.Key() | |||
pdb := NewPrefixDB(db, bz("key")) | |||
checkValue(t, pdb, bz("key"), nil) | |||
checkValue(t, pdb, bz(""), bz("value")) | |||
checkValue(t, pdb, bz("key1"), nil) | |||
checkValue(t, pdb, bz("1"), bz("value1")) | |||
checkValue(t, pdb, bz("key2"), nil) | |||
checkValue(t, pdb, bz("2"), bz("value2")) | |||
checkValue(t, pdb, bz("key3"), nil) | |||
checkValue(t, pdb, bz("3"), bz("value3")) | |||
checkValue(t, pdb, bz("something"), nil) | |||
checkValue(t, pdb, bz("k"), nil) | |||
checkValue(t, pdb, bz("ke"), nil) | |||
checkValue(t, pdb, bz("kee"), nil) | |||
itr := pdb.Iterator(nil, nil) | |||
itr.Key() | |||
checkItem(t, itr, bz(""), bz("value")) | |||
checkNext(t, itr, true) | |||
checkItem(t, itr, bz("1"), bz("value1")) | |||
checkNext(t, itr, true) | |||
checkItem(t, itr, bz("2"), bz("value2")) | |||
checkNext(t, itr, true) | |||
checkItem(t, itr, bz("3"), bz("value3")) | |||
itr.Close() | |||
} |
@ -1,3 +1,3 @@ | |||
package version | |||
const Version = "0.7.1" | |||
const Version = "0.8.1" |