
Nil keys are OK, deprecate BeginningKey/EndingKey (#101)

* Nil keys are OK, deprecate BeginningKey/EndingKey
Branch: pull/1842/head
Jae Kwon, 7 years ago (committed by GitHub)
parent commit 4ce8448d7f
10 changed files with 400 additions and 332 deletions
  1. db/backend_test.go (+82, -13)
  2. db/c_level_db.go (+73, -59)
  3. db/common_test.go (+29, -5)
  4. db/fsdb.go (+28, -23)
  5. db/go_level_db.go (+79, -84)
  6. db/mem_db.go (+59, -65)
  7. db/mem_db_test.go (+0, -48)
  8. db/types.go (+23, -22)
  9. db/util.go (+22, -7)
  10. db/util_test.go (+5, -6)

db/backend_test.go (+82, -13)

@ -21,6 +21,13 @@ func testBackendGetSetDelete(t *testing.T, backend string) {
defer dir.Close()
db := NewDB("testdb", backend, dirname)
// A nonexistent key should return nil, even if the key is empty.
require.Nil(t, db.Get([]byte("")))
// A nonexistent key should return nil, even if the key is nil.
require.Nil(t, db.Get(nil))
// A nonexistent key should return nil.
key := []byte("abc")
require.Nil(t, db.Get(key))
@ -55,27 +62,89 @@ func withDB(t *testing.T, creator dbCreator, fn func(DB)) {
}
func TestBackendsNilKeys(t *testing.T) {
// test all backends
// test all backends.
// nil keys are treated as the empty key for most operations.
for dbType, creator := range backends {
withDB(t, creator, func(db DB) {
panicMsg := "expecting %s.%s to panic"
assert.Panics(t, func() { db.Get(nil) }, panicMsg, dbType, "get")
assert.Panics(t, func() { db.Has(nil) }, panicMsg, dbType, "has")
assert.Panics(t, func() { db.Set(nil, []byte("abc")) }, panicMsg, dbType, "set")
assert.Panics(t, func() { db.SetSync(nil, []byte("abc")) }, panicMsg, dbType, "setsync")
assert.Panics(t, func() { db.Delete(nil) }, panicMsg, dbType, "delete")
assert.Panics(t, func() { db.DeleteSync(nil) }, panicMsg, dbType, "deletesync")
t.Run(fmt.Sprintf("Testing %s", dbType), func(t *testing.T) {
expect := func(key, value []byte) {
if len(key) == 0 { // nil or empty
assert.Equal(t, db.Get(nil), db.Get([]byte("")))
assert.Equal(t, db.Has(nil), db.Has([]byte("")))
}
assert.Equal(t, db.Get(key), value)
assert.Equal(t, db.Has(key), value != nil)
}
// Not set
expect(nil, nil)
// Set nil value
db.Set(nil, nil)
expect(nil, []byte(""))
// Set empty value
db.Set(nil, []byte(""))
expect(nil, []byte(""))
// Set nil, Delete nil
db.Set(nil, []byte("abc"))
expect(nil, []byte("abc"))
db.Delete(nil)
expect(nil, nil)
// Set nil, Delete empty
db.Set(nil, []byte("abc"))
expect(nil, []byte("abc"))
db.Delete([]byte(""))
expect(nil, nil)
// Set empty, Delete nil
db.Set([]byte(""), []byte("abc"))
expect(nil, []byte("abc"))
db.Delete(nil)
expect(nil, nil)
// Set empty, Delete empty
db.Set([]byte(""), []byte("abc"))
expect(nil, []byte("abc"))
db.Delete([]byte(""))
expect(nil, nil)
// SetSync nil, DeleteSync nil
db.SetSync(nil, []byte("abc"))
expect(nil, []byte("abc"))
db.DeleteSync(nil)
expect(nil, nil)
// SetSync nil, DeleteSync empty
db.SetSync(nil, []byte("abc"))
expect(nil, []byte("abc"))
db.DeleteSync([]byte(""))
expect(nil, nil)
// SetSync empty, DeleteSync nil
db.SetSync([]byte(""), []byte("abc"))
expect(nil, []byte("abc"))
db.DeleteSync(nil)
expect(nil, nil)
// SetSync empty, DeleteSync empty
db.SetSync([]byte(""), []byte("abc"))
expect(nil, []byte("abc"))
db.DeleteSync([]byte(""))
expect(nil, nil)
})
})
}
}
func TestGoLevelDBBackendStr(t *testing.T) {
name := cmn.Fmt("test_%x", cmn.RandStr(12))
db := NewDB(name, LevelDBBackendStr, "")
db := NewDB(name, GoLevelDBBackendStr, "")
defer cleanupDBDir("", name)
if _, ok := backends[CLevelDBBackendStr]; !ok {
_, ok := db.(*GoLevelDB)
assert.True(t, ok)
}
_, ok := db.(*GoLevelDB)
assert.True(t, ok)
}
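As a quick illustration of what the expect helper above asserts (a sketch under the same package-db assumption, not part of the diff): a key that was never set, or was deleted, reads back as nil; a key set to a nil or empty value reads back as a non-nil empty byteslice; and Has reports exactly Get(key) != nil.

func exampleEmptyVsUnset() {
    // Sketch: assumes placement in package db with "fmt" imported.
    db := NewMemDB()
    fmt.Println(db.Get([]byte("x")) == nil) // true: never set
    db.Set([]byte("x"), nil)                // a nil value is stored as []byte{}
    fmt.Println(len(db.Get([]byte("x"))))   // 0, but the slice is non-nil
    fmt.Println(db.Has([]byte("x")))        // true
    db.Delete([]byte("x"))
    fmt.Println(db.Has([]byte("x"))) // false again
}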

db/c_level_db.go (+73, -59)

@ -51,7 +51,7 @@ func NewCLevelDB(name string, dir string) (*CLevelDB, error) {
}
func (db *CLevelDB) Get(key []byte) []byte {
panicNilKey(key)
key = nonNilBytes(key)
res, err := db.db.Get(db.ro, key)
if err != nil {
panic(err)
@ -60,12 +60,12 @@ func (db *CLevelDB) Get(key []byte) []byte {
}
func (db *CLevelDB) Has(key []byte) bool {
panicNilKey(key)
panic("not implemented yet")
return db.Get(key) != nil
}
func (db *CLevelDB) Set(key []byte, value []byte) {
panicNilKey(key)
key = nonNilBytes(key)
value = nonNilBytes(value)
err := db.db.Put(db.wo, key, value)
if err != nil {
panic(err)
@ -73,7 +73,8 @@ func (db *CLevelDB) Set(key []byte, value []byte) {
}
func (db *CLevelDB) SetSync(key []byte, value []byte) {
panicNilKey(key)
key = nonNilBytes(key)
value = nonNilBytes(value)
err := db.db.Put(db.woSync, key, value)
if err != nil {
panic(err)
@ -81,7 +82,7 @@ func (db *CLevelDB) SetSync(key []byte, value []byte) {
}
func (db *CLevelDB) Delete(key []byte) {
panicNilKey(key)
key = nonNilBytes(key)
err := db.db.Delete(db.wo, key)
if err != nil {
panic(err)
@ -89,7 +90,7 @@ func (db *CLevelDB) Delete(key []byte) {
}
func (db *CLevelDB) DeleteSync(key []byte) {
panicNilKey(key)
key = nonNilBytes(key)
err := db.db.Delete(db.woSync, key)
if err != nil {
panic(err)
@ -108,7 +109,7 @@ func (db *CLevelDB) Close() {
}
func (db *CLevelDB) Print() {
itr := db.Iterator(BeginningKey(), EndingKey())
itr := db.Iterator(nil, nil)
defer itr.Close()
for ; itr.Valid(); itr.Next() {
key := itr.Key()
@ -159,94 +160,107 @@ func (mBatch *cLevelDBBatch) Write() {
//----------------------------------------
// Iterator
// NOTE This is almost identical to db/go_level_db.Iterator
// Before creating a third version, refactor.
func (db *CLevelDB) Iterator(start, end []byte) Iterator {
itr := db.db.NewIterator(db.ro)
return newCLevelDBIterator(itr, start, end)
return newCLevelDBIterator(itr, start, end, false)
}
func (db *CLevelDB) ReverseIterator(start, end []byte) Iterator {
// XXX
return nil
panic("not implemented yet") // XXX
}
var _ Iterator = (*cLevelDBIterator)(nil)
type cLevelDBIterator struct {
itr *levigo.Iterator
source *levigo.Iterator
start, end []byte
invalid bool
isReverse bool
isInvalid bool
}
func newCLevelDBIterator(itr *levigo.Iterator, start, end []byte) *cLevelDBIterator {
if len(start) > 0 {
itr.Seek(start)
func newCLevelDBIterator(source *levigo.Iterator, start, end []byte, isReverse bool) *cLevelDBIterator {
if isReverse {
panic("not implemented yet") // XXX
}
if start != nil {
source.Seek(start)
} else {
itr.SeekToFirst()
source.SeekToFirst()
}
return &cLevelDBIterator{
itr: itr,
start: start,
end: end,
source: source,
start: start,
end: end,
isReverse: isReverse,
isInvalid: false,
}
}
func (c *cLevelDBIterator) Domain() ([]byte, []byte) {
return c.start, c.end
func (itr *cLevelDBIterator) Domain() ([]byte, []byte) {
return itr.start, itr.end
}
func (c *cLevelDBIterator) Valid() bool {
c.assertNoError()
if c.invalid {
func (itr *cLevelDBIterator) Valid() bool {
// Once invalid, forever invalid.
if itr.isInvalid {
return false
}
c.invalid = !c.itr.Valid()
return !c.invalid
}
func (c *cLevelDBIterator) Key() []byte {
if !c.Valid() {
panic("cLevelDBIterator Key() called when invalid")
// Panic on DB error. No way to recover.
itr.assertNoError()
// If source is invalid, invalid.
if !itr.source.Valid() {
itr.isInvalid = true
return false
}
return c.itr.Key()
}
func (c *cLevelDBIterator) Value() []byte {
if !c.Valid() {
panic("cLevelDBIterator Value() called when invalid")
// If key is end or past it, invalid.
var end = itr.end
var key = itr.source.Key()
if end != nil && bytes.Compare(end, key) <= 0 {
itr.isInvalid = true
return false
}
return c.itr.Value()
// Valid
return true
}
func (c *cLevelDBIterator) Next() {
if !c.Valid() {
panic("cLevelDBIterator Next() called when invalid")
}
c.itr.Next()
c.checkEndKey() // if we've exceeded the range, we're now invalid
func (itr *cLevelDBIterator) Key() []byte {
itr.assertNoError()
itr.assertIsValid()
return itr.source.Key()
}
// levigo has no upper bound when iterating, so need to check ourselves
func (c *cLevelDBIterator) checkEndKey() {
if !c.itr.Valid() {
c.invalid = true
return
}
func (itr *cLevelDBIterator) Value() []byte {
itr.assertNoError()
itr.assertIsValid()
return itr.source.Value()
}
key := c.itr.Key()
if c.end != nil && bytes.Compare(key, c.end) > 0 {
c.invalid = true
}
func (itr *cLevelDBIterator) Next() {
itr.assertNoError()
itr.assertIsValid()
itr.source.Next()
}
func (c *cLevelDBIterator) Close() {
c.itr.Close()
func (itr *cLevelDBIterator) Close() {
itr.source.Close()
}
func (c *cLevelDBIterator) assertNoError() {
if err := c.itr.GetError(); err != nil {
func (itr *cLevelDBIterator) assertNoError() {
if err := itr.source.GetError(); err != nil {
panic(err)
}
}
func (itr cLevelDBIterator) assertIsValid() {
if !itr.Valid() {
panic("cLevelDBIterator is invalid")
}
}
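The rewritten iterator above enforces the half-open [start, end) domain inside Valid() via bytes.Compare(end, key) <= 0, so the end key itself is never yielded. A small, backend-agnostic sketch of the intended behavior (same package-db assumption as above, not part of the diff):

func exampleHalfOpenRange() {
    // Sketch: assumes placement in package db with "fmt" imported.
    db := NewMemDB()
    for _, k := range []string{"a", "b", "c"} {
        db.Set([]byte(k), []byte("v"))
    }
    itr := db.Iterator([]byte("a"), []byte("c")) // yields "a" and "b"; "c" is excluded
    defer itr.Close()
    for ; itr.Valid(); itr.Next() {
        fmt.Println(string(itr.Key()))
    }
}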

db/common_test.go (+29, -5)

@ -57,7 +57,7 @@ func TestDBIteratorSingleKey(t *testing.T) {
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
db := newTempDB(t, backend)
db.SetSync(bz("1"), bz("value_1"))
itr := db.Iterator(BeginningKey(), EndingKey())
itr := db.Iterator(nil, nil)
checkValid(t, itr, true)
checkNext(t, itr, false)
@ -78,7 +78,7 @@ func TestDBIteratorTwoKeys(t *testing.T) {
db.SetSync(bz("2"), bz("value_1"))
{ // Fail by calling Next too much
itr := db.Iterator(BeginningKey(), EndingKey())
itr := db.Iterator(nil, nil)
checkValid(t, itr, true)
checkNext(t, itr, true)
@ -96,11 +96,35 @@ func TestDBIteratorTwoKeys(t *testing.T) {
}
}
func TestDBIteratorMany(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
db := newTempDB(t, backend)
keys := make([][]byte, 100)
for i := 0; i < 100; i++ {
keys[i] = []byte{byte(i)}
}
value := []byte{5}
for _, k := range keys {
db.Set(k, value)
}
itr := db.Iterator(nil, nil)
defer itr.Close()
for ; itr.Valid(); itr.Next() {
assert.Equal(t, db.Get(itr.Key()), itr.Value())
}
})
}
}
func TestDBIteratorEmpty(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
db := newTempDB(t, backend)
itr := db.Iterator(BeginningKey(), EndingKey())
itr := db.Iterator(nil, nil)
checkInvalid(t, itr)
})
@ -111,7 +135,7 @@ func TestDBIteratorEmptyBeginAfter(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
db := newTempDB(t, backend)
itr := db.Iterator(bz("1"), EndingKey())
itr := db.Iterator(bz("1"), nil)
checkInvalid(t, itr)
})
@ -123,7 +147,7 @@ func TestDBIteratorNonemptyBeginAfter(t *testing.T) {
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) {
db := newTempDB(t, backend)
db.SetSync(bz("1"), bz("value_1"))
itr := db.Iterator(bz("2"), EndingKey())
itr := db.Iterator(bz("2"), nil)
checkInvalid(t, itr)
})


db/fsdb.go (+28, -23)

@ -47,7 +47,7 @@ func NewFSDB(dir string) *FSDB {
func (db *FSDB) Get(key []byte) []byte {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
key = escapeKey(key)
path := db.nameToPath(key)
value, err := read(path)
@ -62,7 +62,7 @@ func (db *FSDB) Get(key []byte) []byte {
func (db *FSDB) Has(key []byte) bool {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
key = escapeKey(key)
path := db.nameToPath(key)
return cmn.FileExists(path)
@ -71,7 +71,6 @@ func (db *FSDB) Has(key []byte) bool {
func (db *FSDB) Set(key []byte, value []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.SetNoLock(key, value)
}
@ -79,17 +78,14 @@ func (db *FSDB) Set(key []byte, value []byte) {
func (db *FSDB) SetSync(key []byte, value []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.SetNoLock(key, value)
}
// NOTE: Implements atomicSetDeleter.
func (db *FSDB) SetNoLock(key []byte, value []byte) {
panicNilKey(key)
if value == nil {
value = []byte{}
}
key = escapeKey(key)
value = nonNilBytes(value)
path := db.nameToPath(key)
err := write(path, value)
if err != nil {
@ -100,7 +96,6 @@ func (db *FSDB) SetNoLock(key []byte, value []byte) {
func (db *FSDB) Delete(key []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.DeleteNoLock(key)
}
@ -108,14 +103,13 @@ func (db *FSDB) Delete(key []byte) {
func (db *FSDB) DeleteSync(key []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.DeleteNoLock(key)
}
// NOTE: Implements atomicSetDeleter.
func (db *FSDB) DeleteNoLock(key []byte) {
panicNilKey(key)
key = escapeKey(key)
path := db.nameToPath(key)
err := remove(path)
if os.IsNotExist(err) {
@ -157,8 +151,6 @@ func (db *FSDB) Mutex() *sync.Mutex {
}
func (db *FSDB) Iterator(start, end []byte) Iterator {
it := newMemDBIterator(db, start, end)
db.mtx.Lock()
defer db.mtx.Unlock()
@ -169,13 +161,11 @@ func (db *FSDB) Iterator(start, end []byte) Iterator {
panic(errors.Wrapf(err, "Listing keys in %s", db.dir))
}
sort.Strings(keys)
it.keys = keys
return it
return newMemDBIterator(db, keys, start, end)
}
func (db *FSDB) ReverseIterator(start, end []byte) Iterator {
// XXX
return nil
panic("not implemented yet") // XXX
}
func (db *FSDB) nameToPath(name []byte) string {
@ -221,8 +211,7 @@ func remove(path string) error {
return os.Remove(path)
}
// List files of a path.
// Paths will NOT include dir as the prefix.
// List keys in a directory, stripping off escape sequences and dir portions.
// CONTRACT: returns os errors directly without wrapping.
func list(dirPath string, start, end []byte) ([]string, error) {
dir, err := os.Open(dirPath)
@ -235,15 +224,31 @@ func list(dirPath string, start, end []byte) ([]string, error) {
if err != nil {
return nil, err
}
var paths []string
var keys []string
for _, name := range names {
n, err := url.PathUnescape(name)
if err != nil {
return nil, fmt.Errorf("Failed to unescape %s while listing", name)
}
if IsKeyInDomain([]byte(n), start, end) {
paths = append(paths, n)
key := unescapeKey([]byte(n))
if IsKeyInDomain(key, start, end, false) {
keys = append(keys, string(key))
}
}
return paths, nil
return keys, nil
}
// To support empty or nil keys, while the file system doesn't allow empty
// filenames.
func escapeKey(key []byte) []byte {
return []byte("k_" + string(key))
}
func unescapeKey(escKey []byte) []byte {
if len(escKey) < 2 {
panic(fmt.Sprintf("Invalid esc key: %x", escKey))
}
if string(escKey[:2]) != "k_" {
panic(fmt.Sprintf("Invalid esc key: %x", escKey))
}
return escKey[2:]
}
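The new escapeKey/unescapeKey pair exists only because the file system cannot store an empty filename: every key is prefixed with "k_" before being turned into a path, and the prefix is stripped again when listing. A tiny round-trip sketch (not part of the diff):

func exampleEscapeKey() {
    // Sketch: assumes placement in package db with "fmt" imported.
    esc := escapeKey(nil)    // "k_": even the nil/empty key gets a usable filename stem
    fmt.Println(string(esc)) // k_
    key := unescapeKey(esc)  // back to the original (empty) key
    fmt.Println(len(key))    // 0
    fmt.Println(string(unescapeKey(escapeKey([]byte("abc"))))) // abc
}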

db/go_level_db.go (+79, -84)

@ -1,6 +1,7 @@
package db
import (
"bytes"
"fmt"
"path/filepath"
@ -8,7 +9,6 @@ import (
"github.com/syndtr/goleveldb/leveldb/errors"
"github.com/syndtr/goleveldb/leveldb/iterator"
"github.com/syndtr/goleveldb/leveldb/opt"
"github.com/syndtr/goleveldb/leveldb/util"
. "github.com/tendermint/tmlibs/common"
)
@ -40,33 +40,25 @@ func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) {
}
func (db *GoLevelDB) Get(key []byte) []byte {
panicNilKey(key)
key = nonNilBytes(key)
res, err := db.db.Get(key, nil)
if err != nil {
if err == errors.ErrNotFound {
return nil
} else {
PanicCrisis(err)
panic(err)
}
}
return res
}
func (db *GoLevelDB) Has(key []byte) bool {
panicNilKey(key)
_, err := db.db.Get(key, nil)
if err != nil {
if err == errors.ErrNotFound {
return false
} else {
PanicCrisis(err)
}
}
return true
return db.Get(key) != nil
}
func (db *GoLevelDB) Set(key []byte, value []byte) {
panicNilKey(key)
key = nonNilBytes(key)
value = nonNilBytes(value)
err := db.db.Put(key, value, nil)
if err != nil {
PanicCrisis(err)
@ -74,7 +66,8 @@ func (db *GoLevelDB) Set(key []byte, value []byte) {
}
func (db *GoLevelDB) SetSync(key []byte, value []byte) {
panicNilKey(key)
key = nonNilBytes(key)
value = nonNilBytes(value)
err := db.db.Put(key, value, &opt.WriteOptions{Sync: true})
if err != nil {
PanicCrisis(err)
@ -82,7 +75,7 @@ func (db *GoLevelDB) SetSync(key []byte, value []byte) {
}
func (db *GoLevelDB) Delete(key []byte) {
panicNilKey(key)
key = nonNilBytes(key)
err := db.db.Delete(key, nil)
if err != nil {
PanicCrisis(err)
@ -90,7 +83,7 @@ func (db *GoLevelDB) Delete(key []byte) {
}
func (db *GoLevelDB) DeleteSync(key []byte) {
panicNilKey(key)
key = nonNilBytes(key)
err := db.db.Delete(key, &opt.WriteOptions{Sync: true})
if err != nil {
PanicCrisis(err)
@ -169,102 +162,104 @@ func (mBatch *goLevelDBBatch) Write() {
//----------------------------------------
// Iterator
// NOTE This is almost identical to db/c_level_db.Iterator
// Before creating a third version, refactor.
type goLevelDBIterator struct {
source iterator.Iterator
start []byte
end []byte
isReverse bool
isInvalid bool
}
var _ Iterator = (*goLevelDBIterator)(nil)
// https://godoc.org/github.com/syndtr/goleveldb/leveldb#DB.NewIterator
// A nil Range.Start is treated as a key before all keys in the DB.
// And a nil Range.Limit is treated as a key after all keys in the DB.
func goLevelDBIterRange(start, end []byte) *util.Range {
// XXX: what if start == nil ?
if len(start) == 0 {
start = nil
func newGoLevelDBIterator(source iterator.Iterator, start, end []byte, isReverse bool) *goLevelDBIterator {
if isReverse {
panic("not implemented yet") // XXX
}
return &util.Range{
Start: start,
Limit: end,
source.Seek(start)
return &goLevelDBIterator{
source: source,
start: start,
end: end,
isReverse: isReverse,
isInvalid: false,
}
}
func (db *GoLevelDB) Iterator(start, end []byte) Iterator {
itrRange := goLevelDBIterRange(start, end)
itr := db.db.NewIterator(itrRange, nil)
itr.Seek(start) // if we don't call this the itr is never valid (?!)
return &goLevelDBIterator{
source: itr,
start: start,
end: end,
}
itr := db.db.NewIterator(nil, nil)
return newGoLevelDBIterator(itr, start, end, false)
}
func (db *GoLevelDB) ReverseIterator(start, end []byte) Iterator {
// XXX
return nil
panic("not implemented yet") // XXX
}
var _ Iterator = (*goLevelDBIterator)(nil)
type goLevelDBIterator struct {
source iterator.Iterator
invalid bool
start, end []byte
func (itr *goLevelDBIterator) Domain() ([]byte, []byte) {
return itr.start, itr.end
}
func (it *goLevelDBIterator) Domain() ([]byte, []byte) {
return it.start, it.end
}
func (itr *goLevelDBIterator) Valid() bool {
// Key returns a copy of the current key.
func (it *goLevelDBIterator) Key() []byte {
if !it.Valid() {
panic("goLevelDBIterator Key() called when invalid")
// Once invalid, forever invalid.
if itr.isInvalid {
return false
}
key := it.source.Key()
k := make([]byte, len(key))
copy(k, key)
return k
}
// Panic on DB error. No way to recover.
itr.assertNoError()
// Value returns a copy of the current value.
func (it *goLevelDBIterator) Value() []byte {
if !it.Valid() {
panic("goLevelDBIterator Value() called when invalid")
// If source is invalid, invalid.
if !itr.source.Valid() {
itr.isInvalid = true
return false
}
val := it.source.Value()
v := make([]byte, len(val))
copy(v, val)
return v
}
func (it *goLevelDBIterator) Valid() bool {
it.assertNoError()
if it.invalid {
// If key is end or past it, invalid.
var end = itr.end
var key = itr.source.Key()
if end != nil && bytes.Compare(end, key) <= 0 {
itr.isInvalid = true
return false
}
it.invalid = !it.source.Valid()
return !it.invalid
// Valid
return true
}
func (it *goLevelDBIterator) Next() {
if !it.Valid() {
panic("goLevelDBIterator Next() called when invalid")
}
it.source.Next()
func (itr *goLevelDBIterator) Key() []byte {
itr.assertNoError()
itr.assertIsValid()
return itr.source.Key()
}
func (it *goLevelDBIterator) Prev() {
if !it.Valid() {
panic("goLevelDBIterator Prev() called when invalid")
}
it.source.Prev()
func (itr *goLevelDBIterator) Value() []byte {
itr.assertNoError()
itr.assertIsValid()
return itr.source.Value()
}
func (itr *goLevelDBIterator) Next() {
itr.assertNoError()
itr.assertIsValid()
itr.source.Next()
}
func (it *goLevelDBIterator) Close() {
it.source.Release()
func (itr *goLevelDBIterator) Close() {
itr.source.Release()
}
func (it *goLevelDBIterator) assertNoError() {
if err := it.source.Error(); err != nil {
func (itr *goLevelDBIterator) assertNoError() {
if err := itr.source.Error(); err != nil {
panic(err)
}
}
func (itr goLevelDBIterator) assertIsValid() {
if !itr.Valid() {
panic("goLevelDBIterator is invalid")
}
}
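Both rewritten LevelDB iterators share the same shape: Valid() latches isInvalid once the source runs out or the end bound is reached, and Key/Value/Next assert validity instead of silently misbehaving. The calling pattern is unchanged across backends (sketch, same assumptions as above):

func examplePastTheEnd() {
    // Sketch: assumes placement in package db with "fmt" imported.
    db := NewMemDB() // any backend with the same Iterator contract
    db.Set([]byte("1"), []byte("value_1"))
    itr := db.Iterator(nil, nil)
    defer itr.Close()
    for ; itr.Valid(); itr.Next() {
        _ = itr.Value()
    }
    fmt.Println(itr.Valid()) // false, and it stays false
    // Calling itr.Key() or itr.Next() here would panic: the iterator asserts validity.
}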

db/mem_db.go (+59, -65)

@ -29,14 +29,16 @@ func NewMemDB() *MemDB {
func (db *MemDB) Get(key []byte) []byte {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
key = nonNilBytes(key)
return db.db[string(key)]
}
func (db *MemDB) Has(key []byte) bool {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
key = nonNilBytes(key)
_, ok := db.db[string(key)]
return ok
}
@ -44,43 +46,43 @@ func (db *MemDB) Has(key []byte) bool {
func (db *MemDB) Set(key []byte, value []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.SetNoLock(key, value)
}
func (db *MemDB) SetSync(key []byte, value []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
db.SetNoLock(key, value)
}
// NOTE: Implements atomicSetDeleter
func (db *MemDB) SetNoLock(key []byte, value []byte) {
if value == nil {
value = []byte{}
}
panicNilKey(key)
key = nonNilBytes(key)
value = nonNilBytes(value)
db.db[string(key)] = value
}
func (db *MemDB) Delete(key []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
delete(db.db, string(key))
db.DeleteNoLock(key)
}
func (db *MemDB) DeleteSync(key []byte) {
db.mtx.Lock()
defer db.mtx.Unlock()
panicNilKey(key)
delete(db.db, string(key))
db.DeleteNoLock(key)
}
// NOTE: Implements atomicSetDeleter
func (db *MemDB) DeleteNoLock(key []byte) {
panicNilKey(key)
key = nonNilBytes(key)
delete(db.db, string(key))
}
@ -125,100 +127,92 @@ func (db *MemDB) Mutex() *sync.Mutex {
//----------------------------------------
func (db *MemDB) Iterator(start, end []byte) Iterator {
it := newMemDBIterator(db, start, end)
db.mtx.Lock()
defer db.mtx.Unlock()
// We need a copy of all of the keys.
// Not the best, but probably not a bottleneck depending.
it.keys = db.getSortedKeys(start, end)
return it
keys := db.getSortedKeys(start, end, false)
return newMemDBIterator(db, keys, start, end)
}
func (db *MemDB) ReverseIterator(start, end []byte) Iterator {
it := newMemDBIterator(db, start, end)
db.mtx.Lock()
defer db.mtx.Unlock()
// We need a copy of all of the keys.
// Not the best, but probably not a bottleneck depending.
it.keys = db.getSortedKeys(end, start)
// reverse the order
l := len(it.keys) - 1
for i, v := range it.keys {
it.keys[i] = it.keys[l-i]
it.keys[l-i] = v
}
return nil
keys := db.getSortedKeys(end, start, true)
return newMemDBIterator(db, keys, start, end)
}
func (db *MemDB) getSortedKeys(start, end []byte) []string {
func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string {
keys := []string{}
for key, _ := range db.db {
if IsKeyInDomain([]byte(key), start, end) {
if IsKeyInDomain([]byte(key), start, end, false) {
keys = append(keys, key)
}
}
sort.Strings(keys)
if reverse {
nkeys := len(keys)
for i := 0; i < nkeys/2; i++ {
keys[i], keys[nkeys-i-1] = keys[nkeys-i-1], keys[i]
}
}
return keys
}
var _ Iterator = (*memDBIterator)(nil)
// We need a copy of all of the keys.
// Not the best, but probably not a bottleneck depending.
type memDBIterator struct {
cur int
keys []string
db DB
start, end []byte
db DB
cur int
keys []string
start []byte
end []byte
}
func newMemDBIterator(db DB, start, end []byte) *memDBIterator {
// Keys is expected to be in reverse order for reverse iterators.
func newMemDBIterator(db DB, keys []string, start, end []byte) *memDBIterator {
return &memDBIterator{
db: db,
cur: 0,
keys: keys,
start: start,
end: end,
}
}
func (it *memDBIterator) Domain() ([]byte, []byte) {
return it.start, it.end
func (itr *memDBIterator) Domain() ([]byte, []byte) {
return itr.start, itr.end
}
func (it *memDBIterator) Valid() bool {
return 0 <= it.cur && it.cur < len(it.keys)
func (itr *memDBIterator) Valid() bool {
return 0 <= itr.cur && itr.cur < len(itr.keys)
}
func (it *memDBIterator) Next() {
if !it.Valid() {
panic("memDBIterator Next() called when invalid")
}
it.cur++
func (itr *memDBIterator) Next() {
itr.assertIsValid()
itr.cur++
}
func (it *memDBIterator) Prev() {
if !it.Valid() {
panic("memDBIterator Next() called when invalid")
}
it.cur--
func (itr *memDBIterator) Key() []byte {
itr.assertIsValid()
return []byte(itr.keys[itr.cur])
}
func (it *memDBIterator) Key() []byte {
if !it.Valid() {
panic("memDBIterator Key() called when invalid")
}
return []byte(it.keys[it.cur])
func (itr *memDBIterator) Value() []byte {
itr.assertIsValid()
key := []byte(itr.keys[itr.cur])
return itr.db.Get(key)
}
func (it *memDBIterator) Value() []byte {
if !it.Valid() {
panic("memDBIterator Value() called when invalid")
}
return it.db.Get(it.Key())
func (itr *memDBIterator) Close() {
itr.keys = nil
itr.db = nil
}
func (it *memDBIterator) Close() {
it.db = nil
it.keys = nil
func (itr *memDBIterator) assertIsValid() {
if !itr.Valid() {
panic("memDBIterator is invalid")
}
}
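MemDB's ReverseIterator, which previously built an iterator and then returned nil, now returns a working iterator whose key list is sorted and then reversed. A hedged usage sketch (descending order is the intended behavior, same assumptions as above):

func exampleReverseIterator() {
    // Sketch: assumes placement in package db with "fmt" imported.
    db := NewMemDB()
    for _, k := range []string{"a", "b", "c"} {
        db.Set([]byte(k), []byte("v"))
    }
    ritr := db.ReverseIterator(nil, nil) // nil bounds: walk everything, greatest key first
    defer ritr.Close()
    for ; ritr.Valid(); ritr.Next() {
        fmt.Println(string(ritr.Key())) // visits keys in descending order
    }
}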

db/mem_db_test.go (+0, -48)

@ -1,48 +0,0 @@
package db
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMemDBIterator(t *testing.T) {
db := NewMemDB()
keys := make([][]byte, 100)
for i := 0; i < 100; i++ {
keys[i] = []byte{byte(i)}
}
value := []byte{5}
for _, k := range keys {
db.Set(k, value)
}
iter := db.Iterator(BeginningKey(), EndingKey())
i := 0
for ; iter.Valid(); iter.Next() {
assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key")
i += 1
}
assert.Equal(t, i, len(db.db), "iterator didnt cover whole db")
}
func TestMemDBClose(t *testing.T) {
db := NewMemDB()
copyDB := func(orig map[string][]byte) map[string][]byte {
copy := make(map[string][]byte)
for k, v := range orig {
copy[k] = v
}
return copy
}
k, v := []byte("foo"), []byte("bar")
db.Set(k, v)
require.Equal(t, db.Get(k), v, "expecting a successful get")
copyBefore := copyDB(db.db)
db.Close()
require.Equal(t, db.Get(k), v, "Close is a noop, expecting a successful get")
copyAfter := copyDB(db.db)
require.Equal(t, copyBefore, copyAfter, "Close is a noop and shouldn't modify any internal data")
}

db/types.go (+23, -22)

@ -2,31 +2,39 @@ package db
type DB interface {
// Get returns nil iff key doesn't exist. Panics on nil key.
// Get returns nil iff key doesn't exist.
// A nil key is interpreted as an empty byteslice.
Get([]byte) []byte
// Has checks if a key exists. Panics on nil key.
// Has checks if a key exists.
// A nil key is interpreted as an empty byteslice.
Has(key []byte) bool
// Set sets the key. Panics on nil key.
// Set sets the key.
// A nil key is interpreted as an empty byteslice.
Set([]byte, []byte)
SetSync([]byte, []byte)
// Delete deletes the key. Panics on nil key.
// Delete deletes the key.
// A nil key is interpreted as an empty byteslice.
Delete([]byte)
DeleteSync([]byte)
// Iterator over a domain of keys in ascending order. End is exclusive.
// Iterate over a domain of keys in ascending order. End is exclusive.
// Start must be less than end, or the Iterator is invalid.
// A nil start is interpreted as an empty byteslice.
// If end is nil, iterates up to the last item (inclusive).
// CONTRACT: No writes may happen within a domain while an iterator exists over it.
Iterator(start, end []byte) Iterator
// Iterator over a domain of keys in descending order. End is exclusive.
// Iterate over a domain of keys in descending order. End is exclusive.
// Start must be greater than end, or the Iterator is invalid.
// If start is nil, iterates from the last/greatest item (inclusive).
// If end is nil, iterates up to the first/least item (inclusive).
// CONTRACT: No writes may happen within a domain while an iterator exists over it.
ReverseIterator(start, end []byte) Iterator
// Releases the connection.
// Closes the connection.
Close()
// Creates a batch for atomic updates.
@ -54,16 +62,6 @@ type SetDeleter interface {
//----------------------------------------
// BeginningKey is the smallest key.
func BeginningKey() []byte {
return []byte{}
}
// EndingKey is the largest key.
func EndingKey() []byte {
return nil
}
/*
Usage:
@ -107,7 +105,7 @@ type Iterator interface {
// If Valid returns false, this method will panic.
Value() []byte
// Release deallocates the given Iterator.
// Close releases the Iterator.
Close()
}
@ -116,9 +114,12 @@ func bz(s string) []byte {
return []byte(s)
}
// All DB funcs should panic on nil key.
func panicNilKey(key []byte) {
if key == nil {
panic("nil key")
// We defensively turn nil keys or values into []byte{} for
// most operations.
func nonNilBytes(bz []byte) []byte {
if bz == nil {
return []byte{}
} else {
return bz
}
}
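nonNilBytes is the whole mechanism behind "nil keys are OK": every backend now normalizes keys and values through it instead of panicking via the removed panicNilKey. For illustration (sketch, same package-db assumption as above):

func exampleNonNilBytes() {
    // Sketch: assumes placement in package db with "fmt" imported.
    fmt.Println(nonNilBytes(nil) == nil)            // false: nil becomes []byte{}
    fmt.Println(len(nonNilBytes(nil)))              // 0
    fmt.Println(string(nonNilBytes([]byte("abc")))) // abc: non-nil input passes through
}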

db/util.go (+22, -7)

@ -7,8 +7,8 @@ import (
func IteratePrefix(db DB, prefix []byte) Iterator {
var start, end []byte
if len(prefix) == 0 {
start = BeginningKey()
end = EndingKey()
start = nil
end = nil
} else {
start = cp(prefix)
end = cpIncr(prefix)
@ -35,11 +35,26 @@ func cpIncr(bz []byte) (ret []byte) {
ret[i] = byte(0x00)
}
}
return EndingKey()
return nil
}
func IsKeyInDomain(key, start, end []byte) bool {
leftCondition := bytes.Equal(start, BeginningKey()) || bytes.Compare(key, start) >= 0
rightCondition := bytes.Equal(end, EndingKey()) || bytes.Compare(key, end) < 0
return leftCondition && rightCondition
// See DB interface documentation for more information.
func IsKeyInDomain(key, start, end []byte, isReverse bool) bool {
if !isReverse {
if bytes.Compare(key, start) < 0 {
return false
}
if end != nil && bytes.Compare(end, key) <= 0 {
return false
}
return true
} else {
if start != nil && bytes.Compare(start, key) < 0 {
return false
}
if end != nil && bytes.Compare(key, end) <= 0 {
return false
}
return true
}
}
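The reworked IsKeyInDomain mirrors the Iterator/ReverseIterator contracts above: forward domains are [start, end) with a nil end meaning unbounded, while reverse domains treat start as the inclusive upper bound and end as the exclusive lower bound. A few concrete checks (sketch, same package-db assumption as above):

func exampleIsKeyInDomain() {
    // Sketch: assumes placement in package db with "fmt" imported.
    fmt.Println(IsKeyInDomain([]byte("b"), []byte("a"), []byte("c"), false)) // true:  "a" <= "b" < "c"
    fmt.Println(IsKeyInDomain([]byte("c"), []byte("a"), []byte("c"), false)) // false: end is exclusive
    fmt.Println(IsKeyInDomain([]byte("z"), []byte("a"), nil, false))         // true:  nil end is unbounded
    fmt.Println(IsKeyInDomain([]byte("b"), []byte("c"), []byte("a"), true))  // true:  reverse, "a" < "b" <= "c"
    fmt.Println(IsKeyInDomain([]byte("a"), []byte("c"), []byte("a"), true))  // false: reverse end is exclusive
}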

db/util_test.go (+5, -6)

@ -5,7 +5,7 @@ import (
"testing"
)
// empty iterator for empty db
// Empty iterator for empty db.
func TestPrefixIteratorNoMatchNil(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {
@ -17,7 +17,7 @@ func TestPrefixIteratorNoMatchNil(t *testing.T) {
}
}
// empty iterator for db populated after iterator created
// Empty iterator for db populated after iterator created.
func TestPrefixIteratorNoMatch1(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {
@ -30,7 +30,7 @@ func TestPrefixIteratorNoMatch1(t *testing.T) {
}
}
// empty iterator for prefix starting above db entry
// Empty iterator for prefix starting after db entry.
func TestPrefixIteratorNoMatch2(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {
@ -38,13 +38,12 @@ func TestPrefixIteratorNoMatch2(t *testing.T) {
db.SetSync(bz("3"), bz("value_3"))
itr := IteratePrefix(db, []byte("4"))
// Once invalid...
checkInvalid(t, itr)
})
}
}
// iterator with single val for db with single val, starting from that val
// Iterator with single val for db with single val, starting from that val.
func TestPrefixIteratorMatch1(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {
@ -62,7 +61,7 @@ func TestPrefixIteratorMatch1(t *testing.T) {
}
}
// iterator with prefix iterates over everything with same prefix
// Iterator with prefix iterates over everything with same prefix.
func TestPrefixIteratorMatches1N(t *testing.T) {
for backend, _ := range backends {
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {

