From 4ce8448d7fcf92b040046f894474ce2f7e779b67 Mon Sep 17 00:00:00 2001 From: Jae Kwon Date: Sun, 17 Dec 2017 13:11:28 -0800 Subject: [PATCH] Nil keys are OK, deprecate BeginningKey/EndingKey (#101) * Nil keys are OK, deprecate BeginningKey/EndingKey --- db/backend_test.go | 95 ++++++++++++++++++++++---- db/c_level_db.go | 132 ++++++++++++++++++++---------------- db/common_test.go | 34 ++++++++-- db/fsdb.go | 51 +++++++------- db/go_level_db.go | 163 ++++++++++++++++++++++----------------------- db/mem_db.go | 124 ++++++++++++++++------------------ db/mem_db_test.go | 48 ------------- db/types.go | 45 +++++++------ db/util.go | 29 ++++++-- db/util_test.go | 11 ++- 10 files changed, 400 insertions(+), 332 deletions(-) delete mode 100644 db/mem_db_test.go diff --git a/db/backend_test.go b/db/backend_test.go index 3362fecf6..e103843dc 100644 --- a/db/backend_test.go +++ b/db/backend_test.go @@ -21,6 +21,13 @@ func testBackendGetSetDelete(t *testing.T, backend string) { defer dir.Close() db := NewDB("testdb", backend, dirname) + // A nonexistent key should return nil, even if the key is empty. + require.Nil(t, db.Get([]byte(""))) + + // A nonexistent key should return nil, even if the key is nil. + require.Nil(t, db.Get(nil)) + + // A nonexistent key should return nil. key := []byte("abc") require.Nil(t, db.Get(key)) @@ -55,27 +62,89 @@ func withDB(t *testing.T, creator dbCreator, fn func(DB)) { } func TestBackendsNilKeys(t *testing.T) { - // test all backends + // test all backends. + // nil keys are treated as the empty key for most operations. for dbType, creator := range backends { withDB(t, creator, func(db DB) { - panicMsg := "expecting %s.%s to panic" - assert.Panics(t, func() { db.Get(nil) }, panicMsg, dbType, "get") - assert.Panics(t, func() { db.Has(nil) }, panicMsg, dbType, "has") - assert.Panics(t, func() { db.Set(nil, []byte("abc")) }, panicMsg, dbType, "set") - assert.Panics(t, func() { db.SetSync(nil, []byte("abc")) }, panicMsg, dbType, "setsync") - assert.Panics(t, func() { db.Delete(nil) }, panicMsg, dbType, "delete") - assert.Panics(t, func() { db.DeleteSync(nil) }, panicMsg, dbType, "deletesync") + t.Run(fmt.Sprintf("Testing %s", dbType), func(t *testing.T) { + + expect := func(key, value []byte) { + if len(key) == 0 { // nil or empty + assert.Equal(t, db.Get(nil), db.Get([]byte(""))) + assert.Equal(t, db.Has(nil), db.Has([]byte(""))) + } + assert.Equal(t, db.Get(key), value) + assert.Equal(t, db.Has(key), value != nil) + } + + // Not set + expect(nil, nil) + + // Set nil value + db.Set(nil, nil) + expect(nil, []byte("")) + + // Set empty value + db.Set(nil, []byte("")) + expect(nil, []byte("")) + + // Set nil, Delete nil + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set nil, Delete empty + db.Set(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // Set empty, Delete nil + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete(nil) + expect(nil, nil) + + // Set empty, Delete empty + db.Set([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.Delete([]byte("")) + expect(nil, nil) + + // SetSync nil, DeleteSync nil + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync nil, DeleteSync empty + db.SetSync(nil, []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + + // SetSync empty, DeleteSync nil + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync(nil) + expect(nil, nil) + + // SetSync empty, DeleteSync empty + db.SetSync([]byte(""), []byte("abc")) + expect(nil, []byte("abc")) + db.DeleteSync([]byte("")) + expect(nil, nil) + }) }) } } func TestGoLevelDBBackendStr(t *testing.T) { name := cmn.Fmt("test_%x", cmn.RandStr(12)) - db := NewDB(name, LevelDBBackendStr, "") + db := NewDB(name, GoLevelDBBackendStr, "") defer cleanupDBDir("", name) - if _, ok := backends[CLevelDBBackendStr]; !ok { - _, ok := db.(*GoLevelDB) - assert.True(t, ok) - } + _, ok := db.(*GoLevelDB) + assert.True(t, ok) } diff --git a/db/c_level_db.go b/db/c_level_db.go index 60198d84c..c9f8d419b 100644 --- a/db/c_level_db.go +++ b/db/c_level_db.go @@ -51,7 +51,7 @@ func NewCLevelDB(name string, dir string) (*CLevelDB, error) { } func (db *CLevelDB) Get(key []byte) []byte { - panicNilKey(key) + key = nonNilBytes(key) res, err := db.db.Get(db.ro, key) if err != nil { panic(err) @@ -60,12 +60,12 @@ func (db *CLevelDB) Get(key []byte) []byte { } func (db *CLevelDB) Has(key []byte) bool { - panicNilKey(key) - panic("not implemented yet") + return db.Get(key) != nil } func (db *CLevelDB) Set(key []byte, value []byte) { - panicNilKey(key) + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(db.wo, key, value) if err != nil { panic(err) @@ -73,7 +73,8 @@ func (db *CLevelDB) Set(key []byte, value []byte) { } func (db *CLevelDB) SetSync(key []byte, value []byte) { - panicNilKey(key) + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(db.woSync, key, value) if err != nil { panic(err) @@ -81,7 +82,7 @@ func (db *CLevelDB) SetSync(key []byte, value []byte) { } func (db *CLevelDB) Delete(key []byte) { - panicNilKey(key) + key = nonNilBytes(key) err := db.db.Delete(db.wo, key) if err != nil { panic(err) @@ -89,7 +90,7 @@ func (db *CLevelDB) Delete(key []byte) { } func (db *CLevelDB) DeleteSync(key []byte) { - panicNilKey(key) + key = nonNilBytes(key) err := db.db.Delete(db.woSync, key) if err != nil { panic(err) @@ -108,7 +109,7 @@ func (db *CLevelDB) Close() { } func (db *CLevelDB) Print() { - itr := db.Iterator(BeginningKey(), EndingKey()) + itr := db.Iterator(nil, nil) defer itr.Close() for ; itr.Valid(); itr.Next() { key := itr.Key() @@ -159,94 +160,107 @@ func (mBatch *cLevelDBBatch) Write() { //---------------------------------------- // Iterator +// NOTE This is almost identical to db/go_level_db.Iterator +// Before creating a third version, refactor. func (db *CLevelDB) Iterator(start, end []byte) Iterator { itr := db.db.NewIterator(db.ro) - return newCLevelDBIterator(itr, start, end) + return newCLevelDBIterator(itr, start, end, false) } func (db *CLevelDB) ReverseIterator(start, end []byte) Iterator { - // XXX - return nil + panic("not implemented yet") // XXX } var _ Iterator = (*cLevelDBIterator)(nil) type cLevelDBIterator struct { - itr *levigo.Iterator + source *levigo.Iterator start, end []byte - invalid bool + isReverse bool + isInvalid bool } -func newCLevelDBIterator(itr *levigo.Iterator, start, end []byte) *cLevelDBIterator { - - if len(start) > 0 { - itr.Seek(start) +func newCLevelDBIterator(source *levigo.Iterator, start, end []byte, isReverse bool) *cLevelDBIterator { + if isReverse { + panic("not implemented yet") // XXX + } + if start != nil { + source.Seek(start) } else { - itr.SeekToFirst() + source.SeekToFirst() } - return &cLevelDBIterator{ - itr: itr, - start: start, - end: end, + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, } } -func (c *cLevelDBIterator) Domain() ([]byte, []byte) { - return c.start, c.end +func (itr *cLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end } -func (c *cLevelDBIterator) Valid() bool { - c.assertNoError() - if c.invalid { +func (itr *cLevelDBIterator) Valid() bool { + + // Once invalid, forever invalid. + if itr.isInvalid { return false } - c.invalid = !c.itr.Valid() - return !c.invalid -} -func (c *cLevelDBIterator) Key() []byte { - if !c.Valid() { - panic("cLevelDBIterator Key() called when invalid") + // Panic on DB error. No way to recover. + itr.assertNoError() + + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false } - return c.itr.Key() -} -func (c *cLevelDBIterator) Value() []byte { - if !c.Valid() { - panic("cLevelDBIterator Value() called when invalid") + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true + return false } - return c.itr.Value() + + // Valid + return true } -func (c *cLevelDBIterator) Next() { - if !c.Valid() { - panic("cLevelDBIterator Next() called when invalid") - } - c.itr.Next() - c.checkEndKey() // if we've exceeded the range, we're now invalid +func (itr *cLevelDBIterator) Key() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Key() } -// levigo has no upper bound when iterating, so need to check ourselves -func (c *cLevelDBIterator) checkEndKey() { - if !c.itr.Valid() { - c.invalid = true - return - } +func (itr *cLevelDBIterator) Value() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Value() +} - key := c.itr.Key() - if c.end != nil && bytes.Compare(key, c.end) > 0 { - c.invalid = true - } +func (itr *cLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + itr.source.Next() } -func (c *cLevelDBIterator) Close() { - c.itr.Close() +func (itr *cLevelDBIterator) Close() { + itr.source.Close() } -func (c *cLevelDBIterator) assertNoError() { - if err := c.itr.GetError(); err != nil { +func (itr *cLevelDBIterator) assertNoError() { + if err := itr.source.GetError(); err != nil { panic(err) } } + +func (itr cLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("cLevelDBIterator is invalid") + } +} diff --git a/db/common_test.go b/db/common_test.go index 6b3009795..2a5d01818 100644 --- a/db/common_test.go +++ b/db/common_test.go @@ -57,7 +57,7 @@ func TestDBIteratorSingleKey(t *testing.T) { t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) db.SetSync(bz("1"), bz("value_1")) - itr := db.Iterator(BeginningKey(), EndingKey()) + itr := db.Iterator(nil, nil) checkValid(t, itr, true) checkNext(t, itr, false) @@ -78,7 +78,7 @@ func TestDBIteratorTwoKeys(t *testing.T) { db.SetSync(bz("2"), bz("value_1")) { // Fail by calling Next too much - itr := db.Iterator(BeginningKey(), EndingKey()) + itr := db.Iterator(nil, nil) checkValid(t, itr, true) checkNext(t, itr, true) @@ -96,11 +96,35 @@ func TestDBIteratorTwoKeys(t *testing.T) { } } +func TestDBIteratorMany(t *testing.T) { + for backend, _ := range backends { + t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { + db := newTempDB(t, backend) + + keys := make([][]byte, 100) + for i := 0; i < 100; i++ { + keys[i] = []byte{byte(i)} + } + + value := []byte{5} + for _, k := range keys { + db.Set(k, value) + } + + itr := db.Iterator(nil, nil) + defer itr.Close() + for ; itr.Valid(); itr.Next() { + assert.Equal(t, db.Get(itr.Key()), itr.Value()) + } + }) + } +} + func TestDBIteratorEmpty(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) - itr := db.Iterator(BeginningKey(), EndingKey()) + itr := db.Iterator(nil, nil) checkInvalid(t, itr) }) @@ -111,7 +135,7 @@ func TestDBIteratorEmptyBeginAfter(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) - itr := db.Iterator(bz("1"), EndingKey()) + itr := db.Iterator(bz("1"), nil) checkInvalid(t, itr) }) @@ -123,7 +147,7 @@ func TestDBIteratorNonemptyBeginAfter(t *testing.T) { t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { db := newTempDB(t, backend) db.SetSync(bz("1"), bz("value_1")) - itr := db.Iterator(bz("2"), EndingKey()) + itr := db.Iterator(bz("2"), nil) checkInvalid(t, itr) }) diff --git a/db/fsdb.go b/db/fsdb.go index 056cc3982..45c3231f6 100644 --- a/db/fsdb.go +++ b/db/fsdb.go @@ -47,7 +47,7 @@ func NewFSDB(dir string) *FSDB { func (db *FSDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + key = escapeKey(key) path := db.nameToPath(key) value, err := read(path) @@ -62,7 +62,7 @@ func (db *FSDB) Get(key []byte) []byte { func (db *FSDB) Has(key []byte) bool { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + key = escapeKey(key) path := db.nameToPath(key) return cmn.FileExists(path) @@ -71,7 +71,6 @@ func (db *FSDB) Has(key []byte) bool { func (db *FSDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) db.SetNoLock(key, value) } @@ -79,17 +78,14 @@ func (db *FSDB) Set(key []byte, value []byte) { func (db *FSDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) db.SetNoLock(key, value) } // NOTE: Implements atomicSetDeleter. func (db *FSDB) SetNoLock(key []byte, value []byte) { - panicNilKey(key) - if value == nil { - value = []byte{} - } + key = escapeKey(key) + value = nonNilBytes(value) path := db.nameToPath(key) err := write(path, value) if err != nil { @@ -100,7 +96,6 @@ func (db *FSDB) SetNoLock(key []byte, value []byte) { func (db *FSDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) db.DeleteNoLock(key) } @@ -108,14 +103,13 @@ func (db *FSDB) Delete(key []byte) { func (db *FSDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) db.DeleteNoLock(key) } // NOTE: Implements atomicSetDeleter. func (db *FSDB) DeleteNoLock(key []byte) { - panicNilKey(key) + key = escapeKey(key) path := db.nameToPath(key) err := remove(path) if os.IsNotExist(err) { @@ -157,8 +151,6 @@ func (db *FSDB) Mutex() *sync.Mutex { } func (db *FSDB) Iterator(start, end []byte) Iterator { - it := newMemDBIterator(db, start, end) - db.mtx.Lock() defer db.mtx.Unlock() @@ -169,13 +161,11 @@ func (db *FSDB) Iterator(start, end []byte) Iterator { panic(errors.Wrapf(err, "Listing keys in %s", db.dir)) } sort.Strings(keys) - it.keys = keys - return it + return newMemDBIterator(db, keys, start, end) } func (db *FSDB) ReverseIterator(start, end []byte) Iterator { - // XXX - return nil + panic("not implemented yet") // XXX } func (db *FSDB) nameToPath(name []byte) string { @@ -221,8 +211,7 @@ func remove(path string) error { return os.Remove(path) } -// List files of a path. -// Paths will NOT include dir as the prefix. +// List keys in a directory, stripping of escape sequences and dir portions. // CONTRACT: returns os errors directly without wrapping. func list(dirPath string, start, end []byte) ([]string, error) { dir, err := os.Open(dirPath) @@ -235,15 +224,31 @@ func list(dirPath string, start, end []byte) ([]string, error) { if err != nil { return nil, err } - var paths []string + var keys []string for _, name := range names { n, err := url.PathUnescape(name) if err != nil { return nil, fmt.Errorf("Failed to unescape %s while listing", name) } - if IsKeyInDomain([]byte(n), start, end) { - paths = append(paths, n) + key := unescapeKey([]byte(n)) + if IsKeyInDomain(key, start, end, false) { + keys = append(keys, string(key)) } } - return paths, nil + return keys, nil +} + +// To support empty or nil keys, while the file system doesn't allow empty +// filenames. +func escapeKey(key []byte) []byte { + return []byte("k_" + string(key)) +} +func unescapeKey(escKey []byte) []byte { + if len(escKey) < 2 { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + if string(escKey[:2]) != "k_" { + panic(fmt.Sprintf("Invalid esc key: %x", escKey)) + } + return escKey[2:] } diff --git a/db/go_level_db.go b/db/go_level_db.go index 45cb04984..bf2b3bf76 100644 --- a/db/go_level_db.go +++ b/db/go_level_db.go @@ -1,6 +1,7 @@ package db import ( + "bytes" "fmt" "path/filepath" @@ -8,7 +9,6 @@ import ( "github.com/syndtr/goleveldb/leveldb/errors" "github.com/syndtr/goleveldb/leveldb/iterator" "github.com/syndtr/goleveldb/leveldb/opt" - "github.com/syndtr/goleveldb/leveldb/util" . "github.com/tendermint/tmlibs/common" ) @@ -40,33 +40,25 @@ func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { } func (db *GoLevelDB) Get(key []byte) []byte { - panicNilKey(key) + key = nonNilBytes(key) res, err := db.db.Get(key, nil) if err != nil { if err == errors.ErrNotFound { return nil } else { - PanicCrisis(err) + panic(err) } } return res } func (db *GoLevelDB) Has(key []byte) bool { - panicNilKey(key) - _, err := db.db.Get(key, nil) - if err != nil { - if err == errors.ErrNotFound { - return false - } else { - PanicCrisis(err) - } - } - return true + return db.Get(key) != nil } func (db *GoLevelDB) Set(key []byte, value []byte) { - panicNilKey(key) + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(key, value, nil) if err != nil { PanicCrisis(err) @@ -74,7 +66,8 @@ func (db *GoLevelDB) Set(key []byte, value []byte) { } func (db *GoLevelDB) SetSync(key []byte, value []byte) { - panicNilKey(key) + key = nonNilBytes(key) + value = nonNilBytes(value) err := db.db.Put(key, value, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) @@ -82,7 +75,7 @@ func (db *GoLevelDB) SetSync(key []byte, value []byte) { } func (db *GoLevelDB) Delete(key []byte) { - panicNilKey(key) + key = nonNilBytes(key) err := db.db.Delete(key, nil) if err != nil { PanicCrisis(err) @@ -90,7 +83,7 @@ func (db *GoLevelDB) Delete(key []byte) { } func (db *GoLevelDB) DeleteSync(key []byte) { - panicNilKey(key) + key = nonNilBytes(key) err := db.db.Delete(key, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) @@ -169,102 +162,104 @@ func (mBatch *goLevelDBBatch) Write() { //---------------------------------------- // Iterator +// NOTE This is almost identical to db/c_level_db.Iterator +// Before creating a third version, refactor. + +type goLevelDBIterator struct { + source iterator.Iterator + start []byte + end []byte + isReverse bool + isInvalid bool +} + +var _ Iterator = (*goLevelDBIterator)(nil) -// https://godoc.org/github.com/syndtr/goleveldb/leveldb#DB.NewIterator -// A nil Range.Start is treated as a key before all keys in the DB. -// And a nil Range.Limit is treated as a key after all keys in the DB. -func goLevelDBIterRange(start, end []byte) *util.Range { - // XXX: what if start == nil ? - if len(start) == 0 { - start = nil +func newGoLevelDBIterator(source iterator.Iterator, start, end []byte, isReverse bool) *goLevelDBIterator { + if isReverse { + panic("not implemented yet") // XXX } - return &util.Range{ - Start: start, - Limit: end, + source.Seek(start) + return &goLevelDBIterator{ + source: source, + start: start, + end: end, + isReverse: isReverse, + isInvalid: false, } } func (db *GoLevelDB) Iterator(start, end []byte) Iterator { - itrRange := goLevelDBIterRange(start, end) - itr := db.db.NewIterator(itrRange, nil) - itr.Seek(start) // if we don't call this the itr is never valid (?!) - return &goLevelDBIterator{ - source: itr, - start: start, - end: end, - } + itr := db.db.NewIterator(nil, nil) + return newGoLevelDBIterator(itr, start, end, false) } func (db *GoLevelDB) ReverseIterator(start, end []byte) Iterator { - // XXX - return nil + panic("not implemented yet") // XXX } -var _ Iterator = (*goLevelDBIterator)(nil) - -type goLevelDBIterator struct { - source iterator.Iterator - invalid bool - start, end []byte +func (itr *goLevelDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end } -func (it *goLevelDBIterator) Domain() ([]byte, []byte) { - return it.start, it.end -} +func (itr *goLevelDBIterator) Valid() bool { -// Key returns a copy of the current key. -func (it *goLevelDBIterator) Key() []byte { - if !it.Valid() { - panic("goLevelDBIterator Key() called when invalid") + // Once invalid, forever invalid. + if itr.isInvalid { + return false } - key := it.source.Key() - k := make([]byte, len(key)) - copy(k, key) - return k -} + // Panic on DB error. No way to recover. + itr.assertNoError() -// Value returns a copy of the current value. -func (it *goLevelDBIterator) Value() []byte { - if !it.Valid() { - panic("goLevelDBIterator Value() called when invalid") + // If source is invalid, invalid. + if !itr.source.Valid() { + itr.isInvalid = true + return false } - val := it.source.Value() - v := make([]byte, len(val)) - copy(v, val) - - return v -} -func (it *goLevelDBIterator) Valid() bool { - it.assertNoError() - if it.invalid { + // If key is end or past it, invalid. + var end = itr.end + var key = itr.source.Key() + if end != nil && bytes.Compare(end, key) <= 0 { + itr.isInvalid = true return false } - it.invalid = !it.source.Valid() - return !it.invalid + + // Valid + return true } -func (it *goLevelDBIterator) Next() { - if !it.Valid() { - panic("goLevelDBIterator Next() called when invalid") - } - it.source.Next() +func (itr *goLevelDBIterator) Key() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Key() } -func (it *goLevelDBIterator) Prev() { - if !it.Valid() { - panic("goLevelDBIterator Prev() called when invalid") - } - it.source.Prev() +func (itr *goLevelDBIterator) Value() []byte { + itr.assertNoError() + itr.assertIsValid() + return itr.source.Value() +} + +func (itr *goLevelDBIterator) Next() { + itr.assertNoError() + itr.assertIsValid() + itr.source.Next() } -func (it *goLevelDBIterator) Close() { - it.source.Release() +func (itr *goLevelDBIterator) Close() { + itr.source.Release() } -func (it *goLevelDBIterator) assertNoError() { - if err := it.source.Error(); err != nil { +func (itr *goLevelDBIterator) assertNoError() { + if err := itr.source.Error(); err != nil { panic(err) } } + +func (itr goLevelDBIterator) assertIsValid() { + if !itr.Valid() { + panic("goLevelDBIterator is invalid") + } +} diff --git a/db/mem_db.go b/db/mem_db.go index 44254870a..e9d9174dc 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -29,14 +29,16 @@ func NewMemDB() *MemDB { func (db *MemDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + key = nonNilBytes(key) + return db.db[string(key)] } func (db *MemDB) Has(key []byte) bool { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + key = nonNilBytes(key) + _, ok := db.db[string(key)] return ok } @@ -44,43 +46,43 @@ func (db *MemDB) Has(key []byte) bool { func (db *MemDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + db.SetNoLock(key, value) } func (db *MemDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) + db.SetNoLock(key, value) } // NOTE: Implements atomicSetDeleter func (db *MemDB) SetNoLock(key []byte, value []byte) { - if value == nil { - value = []byte{} - } - panicNilKey(key) + key = nonNilBytes(key) + value = nonNilBytes(value) + db.db[string(key)] = value } func (db *MemDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) - delete(db.db, string(key)) + + db.DeleteNoLock(key) } func (db *MemDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - panicNilKey(key) - delete(db.db, string(key)) + + db.DeleteNoLock(key) } // NOTE: Implements atomicSetDeleter func (db *MemDB) DeleteNoLock(key []byte) { - panicNilKey(key) + key = nonNilBytes(key) + delete(db.db, string(key)) } @@ -125,100 +127,92 @@ func (db *MemDB) Mutex() *sync.Mutex { //---------------------------------------- func (db *MemDB) Iterator(start, end []byte) Iterator { - it := newMemDBIterator(db, start, end) - db.mtx.Lock() defer db.mtx.Unlock() - // We need a copy of all of the keys. - // Not the best, but probably not a bottleneck depending. - it.keys = db.getSortedKeys(start, end) - return it + keys := db.getSortedKeys(start, end, false) + return newMemDBIterator(db, keys, start, end) } func (db *MemDB) ReverseIterator(start, end []byte) Iterator { - it := newMemDBIterator(db, start, end) - db.mtx.Lock() defer db.mtx.Unlock() - // We need a copy of all of the keys. - // Not the best, but probably not a bottleneck depending. - it.keys = db.getSortedKeys(end, start) - // reverse the order - l := len(it.keys) - 1 - for i, v := range it.keys { - it.keys[i] = it.keys[l-i] - it.keys[l-i] = v - } - return nil + keys := db.getSortedKeys(end, start, true) + return newMemDBIterator(db, keys, start, end) } -func (db *MemDB) getSortedKeys(start, end []byte) []string { +func (db *MemDB) getSortedKeys(start, end []byte, reverse bool) []string { keys := []string{} for key, _ := range db.db { - if IsKeyInDomain([]byte(key), start, end) { + if IsKeyInDomain([]byte(key), start, end, false) { keys = append(keys, key) } } sort.Strings(keys) + if reverse { + nkeys := len(keys) + for i := 0; i < nkeys/2; i++ { + keys[i] = keys[nkeys-i-1] + } + } return keys } var _ Iterator = (*memDBIterator)(nil) +// We need a copy of all of the keys. +// Not the best, but probably not a bottleneck depending. type memDBIterator struct { - cur int - keys []string - db DB - start, end []byte + db DB + cur int + keys []string + start []byte + end []byte } -func newMemDBIterator(db DB, start, end []byte) *memDBIterator { +// Keys is expected to be in reverse order for reverse iterators. +func newMemDBIterator(db DB, keys []string, start, end []byte) *memDBIterator { return &memDBIterator{ db: db, + cur: 0, + keys: keys, start: start, end: end, } } -func (it *memDBIterator) Domain() ([]byte, []byte) { - return it.start, it.end +func (itr *memDBIterator) Domain() ([]byte, []byte) { + return itr.start, itr.end } -func (it *memDBIterator) Valid() bool { - return 0 <= it.cur && it.cur < len(it.keys) +func (itr *memDBIterator) Valid() bool { + return 0 <= itr.cur && itr.cur < len(itr.keys) } -func (it *memDBIterator) Next() { - if !it.Valid() { - panic("memDBIterator Next() called when invalid") - } - it.cur++ +func (itr *memDBIterator) Next() { + itr.assertIsValid() + itr.cur++ } -func (it *memDBIterator) Prev() { - if !it.Valid() { - panic("memDBIterator Next() called when invalid") - } - it.cur-- +func (itr *memDBIterator) Key() []byte { + itr.assertIsValid() + return []byte(itr.keys[itr.cur]) } -func (it *memDBIterator) Key() []byte { - if !it.Valid() { - panic("memDBIterator Key() called when invalid") - } - return []byte(it.keys[it.cur]) +func (itr *memDBIterator) Value() []byte { + itr.assertIsValid() + key := []byte(itr.keys[itr.cur]) + return itr.db.Get(key) } -func (it *memDBIterator) Value() []byte { - if !it.Valid() { - panic("memDBIterator Value() called when invalid") - } - return it.db.Get(it.Key()) +func (itr *memDBIterator) Close() { + itr.keys = nil + itr.db = nil } -func (it *memDBIterator) Close() { - it.db = nil - it.keys = nil +func (itr *memDBIterator) assertIsValid() { + if !itr.Valid() { + panic("memDBIterator is invalid") + } } diff --git a/db/mem_db_test.go b/db/mem_db_test.go deleted file mode 100644 index a08a3679b..000000000 --- a/db/mem_db_test.go +++ /dev/null @@ -1,48 +0,0 @@ -package db - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestMemDBIterator(t *testing.T) { - db := NewMemDB() - keys := make([][]byte, 100) - for i := 0; i < 100; i++ { - keys[i] = []byte{byte(i)} - } - - value := []byte{5} - for _, k := range keys { - db.Set(k, value) - } - - iter := db.Iterator(BeginningKey(), EndingKey()) - i := 0 - for ; iter.Valid(); iter.Next() { - assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key") - i += 1 - } - assert.Equal(t, i, len(db.db), "iterator didnt cover whole db") -} - -func TestMemDBClose(t *testing.T) { - db := NewMemDB() - copyDB := func(orig map[string][]byte) map[string][]byte { - copy := make(map[string][]byte) - for k, v := range orig { - copy[k] = v - } - return copy - } - k, v := []byte("foo"), []byte("bar") - db.Set(k, v) - require.Equal(t, db.Get(k), v, "expecting a successful get") - copyBefore := copyDB(db.db) - db.Close() - require.Equal(t, db.Get(k), v, "Close is a noop, expecting a successful get") - copyAfter := copyDB(db.db) - require.Equal(t, copyBefore, copyAfter, "Close is a noop and shouldn't modify any internal data") -} diff --git a/db/types.go b/db/types.go index ee8d69cc1..6e5d2408d 100644 --- a/db/types.go +++ b/db/types.go @@ -2,31 +2,39 @@ package db type DB interface { - // Get returns nil iff key doesn't exist. Panics on nil key. + // Get returns nil iff key doesn't exist. + // A nil key is interpreted as an empty byteslice. Get([]byte) []byte - // Has checks if a key exists. Panics on nil key. + // Has checks if a key exists. + // A nil key is interpreted as an empty byteslice. Has(key []byte) bool - // Set sets the key. Panics on nil key. + // Set sets the key. + // A nil key is interpreted as an empty byteslice. Set([]byte, []byte) SetSync([]byte, []byte) - // Delete deletes the key. Panics on nil key. + // Delete deletes the key. + // A nil key is interpreted as an empty byteslice. Delete([]byte) DeleteSync([]byte) - // Iterator over a domain of keys in ascending order. End is exclusive. + // Iterate over a domain of keys in ascending order. End is exclusive. // Start must be less than end, or the Iterator is invalid. + // A nil start is interpreted as an empty byteslice. + // If end is nil, iterates up to the last item (inclusive). // CONTRACT: No writes may happen within a domain while an iterator exists over it. Iterator(start, end []byte) Iterator - // Iterator over a domain of keys in descending order. End is exclusive. + // Iterate over a domain of keys in descending order. End is exclusive. // Start must be greater than end, or the Iterator is invalid. + // If start is nil, iterates from the last/greatest item (inclusive). + // If end is nil, iterates up to the first/least item (iclusive). // CONTRACT: No writes may happen within a domain while an iterator exists over it. ReverseIterator(start, end []byte) Iterator - // Releases the connection. + // Closes the connection. Close() // Creates a batch for atomic updates. @@ -54,16 +62,6 @@ type SetDeleter interface { //---------------------------------------- -// BeginningKey is the smallest key. -func BeginningKey() []byte { - return []byte{} -} - -// EndingKey is the largest key. -func EndingKey() []byte { - return nil -} - /* Usage: @@ -107,7 +105,7 @@ type Iterator interface { // If Valid returns false, this method will panic. Value() []byte - // Release deallocates the given Iterator. + // Close releases the Iterator. Close() } @@ -116,9 +114,12 @@ func bz(s string) []byte { return []byte(s) } -// All DB funcs should panic on nil key. -func panicNilKey(key []byte) { - if key == nil { - panic("nil key") +// We defensively turn nil keys or values into []byte{} for +// most operations. +func nonNilBytes(bz []byte) []byte { + if bz == nil { + return []byte{} + } else { + return bz } } diff --git a/db/util.go b/db/util.go index 661d0a16f..b0ab7f6ad 100644 --- a/db/util.go +++ b/db/util.go @@ -7,8 +7,8 @@ import ( func IteratePrefix(db DB, prefix []byte) Iterator { var start, end []byte if len(prefix) == 0 { - start = BeginningKey() - end = EndingKey() + start = nil + end = nil } else { start = cp(prefix) end = cpIncr(prefix) @@ -35,11 +35,26 @@ func cpIncr(bz []byte) (ret []byte) { ret[i] = byte(0x00) } } - return EndingKey() + return nil } -func IsKeyInDomain(key, start, end []byte) bool { - leftCondition := bytes.Equal(start, BeginningKey()) || bytes.Compare(key, start) >= 0 - rightCondition := bytes.Equal(end, EndingKey()) || bytes.Compare(key, end) < 0 - return leftCondition && rightCondition +// See DB interface documentation for more information. +func IsKeyInDomain(key, start, end []byte, isReverse bool) bool { + if !isReverse { + if bytes.Compare(key, start) < 0 { + return false + } + if end != nil && bytes.Compare(end, key) <= 0 { + return false + } + return true + } else { + if start != nil && bytes.Compare(start, key) < 0 { + return false + } + if end != nil && bytes.Compare(key, end) <= 0 { + return false + } + return true + } } diff --git a/db/util_test.go b/db/util_test.go index b273f8d46..854448af3 100644 --- a/db/util_test.go +++ b/db/util_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -// empty iterator for empty db +// Empty iterator for empty db. func TestPrefixIteratorNoMatchNil(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { @@ -17,7 +17,7 @@ func TestPrefixIteratorNoMatchNil(t *testing.T) { } } -// empty iterator for db populated after iterator created +// Empty iterator for db populated after iterator created. func TestPrefixIteratorNoMatch1(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { @@ -30,7 +30,7 @@ func TestPrefixIteratorNoMatch1(t *testing.T) { } } -// empty iterator for prefix starting above db entry +// Empty iterator for prefix starting after db entry. func TestPrefixIteratorNoMatch2(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { @@ -38,13 +38,12 @@ func TestPrefixIteratorNoMatch2(t *testing.T) { db.SetSync(bz("3"), bz("value_3")) itr := IteratePrefix(db, []byte("4")) - // Once invalid... checkInvalid(t, itr) }) } } -// iterator with single val for db with single val, starting from that val +// Iterator with single val for db with single val, starting from that val. func TestPrefixIteratorMatch1(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { @@ -62,7 +61,7 @@ func TestPrefixIteratorMatch1(t *testing.T) { } } -// iterator with prefix iterates over everything with same prefix +// Iterator with prefix iterates over everything with same prefix. func TestPrefixIteratorMatches1N(t *testing.T) { for backend, _ := range backends { t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) {