diff --git a/db/backend_test.go b/db/backend_test.go index b4ffecdc6..b21ce0037 100644 --- a/db/backend_test.go +++ b/db/backend_test.go @@ -2,42 +2,79 @@ package db import ( "fmt" + "os" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" cmn "github.com/tendermint/tmlibs/common" ) -func testBackend(t *testing.T, backend string) { +func testBackendGetSetDelete(t *testing.T, backend string) { // Default dir, dirname := cmn.Tempdir(fmt.Sprintf("test_backend_%s_", backend)) defer dir.Close() db := NewDB("testdb", backend, dirname) require.Nil(t, db.Get([]byte(""))) - require.Nil(t, db.Get(nil)) // Set empty ("") db.Set([]byte(""), []byte("")) require.NotNil(t, db.Get([]byte(""))) - require.NotNil(t, db.Get(nil)) require.Empty(t, db.Get([]byte(""))) - require.Empty(t, db.Get(nil)) // Set empty (nil) db.Set([]byte(""), nil) require.NotNil(t, db.Get([]byte(""))) - require.NotNil(t, db.Get(nil)) require.Empty(t, db.Get([]byte(""))) - require.Empty(t, db.Get(nil)) // Delete db.Delete([]byte("")) require.Nil(t, db.Get([]byte(""))) - require.Nil(t, db.Get(nil)) } -func TestBackends(t *testing.T) { - testBackend(t, CLevelDBBackendStr) - testBackend(t, GoLevelDBBackendStr) - testBackend(t, MemDBBackendStr) +func TestBackendsGetSetDelete(t *testing.T) { + for dbType, _ := range backends { + if dbType == "fsdb" { + // TODO: handle + // fsdb cant deal with length 0 keys + continue + } + testBackendGetSetDelete(t, dbType) + } +} + +func assertPanics(t *testing.T, dbType, name string, fn func()) { + defer func() { + r := recover() + assert.NotNil(t, r, cmn.Fmt("expecting %s.%s to panic", dbType, name)) + }() + + fn() +} + +func TestBackendsNilKeys(t *testing.T) { + // test all backends + for dbType, creator := range backends { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db, err := creator(name, "") + assert.Nil(t, err) + defer os.RemoveAll(name) + + assertPanics(t, dbType, "get", func() { db.Get(nil) }) + assertPanics(t, dbType, "has", func() { db.Has(nil) }) + assertPanics(t, dbType, "set", func() { db.Set(nil, []byte("abc")) }) + assertPanics(t, dbType, "setsync", func() { db.SetSync(nil, []byte("abc")) }) + assertPanics(t, dbType, "delete", func() { db.Delete(nil) }) + assertPanics(t, dbType, "deletesync", func() { db.DeleteSync(nil) }) + + db.Close() + } +} + +func TestLevelDBBackendStr(t *testing.T) { + name := cmn.Fmt("test_%x", cmn.RandStr(12)) + db := NewDB(name, LevelDBBackendStr, "") + defer os.RemoveAll(name) + _, ok := db.(*GoLevelDB) + assert.True(t, ok) } diff --git a/db/fsdb.go b/db/fsdb.go index b6e08daf5..19ea9fa3c 100644 --- a/db/fsdb.go +++ b/db/fsdb.go @@ -44,6 +44,7 @@ func NewFSDB(dir string) *FSDB { func (db *FSDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) path := db.nameToPath(key) value, err := read(path) @@ -58,6 +59,7 @@ func (db *FSDB) Get(key []byte) []byte { func (db *FSDB) Has(key []byte) bool { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) path := db.nameToPath(key) _, err := read(path) @@ -72,6 +74,7 @@ func (db *FSDB) Has(key []byte) bool { func (db *FSDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) db.SetNoLock(key, value) } @@ -79,12 +82,14 @@ func (db *FSDB) Set(key []byte, value []byte) { func (db *FSDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) db.SetNoLock(key, value) } // NOTE: Implements atomicSetDeleter. func (db *FSDB) SetNoLock(key []byte, value []byte) { + panicNilKey(key) if value == nil { value = []byte{} } @@ -98,6 +103,7 @@ func (db *FSDB) SetNoLock(key []byte, value []byte) { func (db *FSDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) db.DeleteNoLock(key) } @@ -105,12 +111,14 @@ func (db *FSDB) Delete(key []byte) { func (db *FSDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() + panicNilKey(key) db.DeleteNoLock(key) } // NOTE: Implements atomicSetDeleter. func (db *FSDB) DeleteNoLock(key []byte) { + panicNilKey(key) err := remove(string(key)) if os.IsNotExist(err) { return diff --git a/db/go_level_db.go b/db/go_level_db.go index e8ed99dee..201a31949 100644 --- a/db/go_level_db.go +++ b/db/go_level_db.go @@ -37,6 +37,7 @@ func NewGoLevelDB(name string, dir string) (*GoLevelDB, error) { } func (db *GoLevelDB) Get(key []byte) []byte { + panicNilKey(key) res, err := db.db.Get(key, nil) if err != nil { if err == errors.ErrNotFound { @@ -49,6 +50,7 @@ func (db *GoLevelDB) Get(key []byte) []byte { } func (db *GoLevelDB) Has(key []byte) bool { + panicNilKey(key) _, err := db.db.Get(key, nil) if err != nil { if err == errors.ErrNotFound { @@ -61,6 +63,7 @@ func (db *GoLevelDB) Has(key []byte) bool { } func (db *GoLevelDB) Set(key []byte, value []byte) { + panicNilKey(key) err := db.db.Put(key, value, nil) if err != nil { PanicCrisis(err) @@ -68,6 +71,7 @@ func (db *GoLevelDB) Set(key []byte, value []byte) { } func (db *GoLevelDB) SetSync(key []byte, value []byte) { + panicNilKey(key) err := db.db.Put(key, value, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) @@ -75,6 +79,7 @@ func (db *GoLevelDB) SetSync(key []byte, value []byte) { } func (db *GoLevelDB) Delete(key []byte) { + panicNilKey(key) err := db.db.Delete(key, nil) if err != nil { PanicCrisis(err) @@ -82,6 +87,7 @@ func (db *GoLevelDB) Delete(key []byte) { } func (db *GoLevelDB) DeleteSync(key []byte) { + panicNilKey(key) err := db.db.Delete(key, &opt.WriteOptions{Sync: true}) if err != nil { PanicCrisis(err) diff --git a/db/mem_db.go b/db/mem_db.go index 3127030ae..ebeb2dded 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -27,14 +27,14 @@ func NewMemDB() *MemDB { func (db *MemDB) Get(key []byte) []byte { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) return db.db[string(key)] } func (db *MemDB) Has(key []byte) bool { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) _, ok := db.db[string(key)] return ok } @@ -42,14 +42,14 @@ func (db *MemDB) Has(key []byte) bool { func (db *MemDB) Set(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) db.SetNoLock(key, value) } func (db *MemDB) SetSync(key []byte, value []byte) { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) db.SetNoLock(key, value) } @@ -58,25 +58,27 @@ func (db *MemDB) SetNoLock(key []byte, value []byte) { if value == nil { value = []byte{} } + panicNilKey(key) db.db[string(key)] = value } func (db *MemDB) Delete(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) delete(db.db, string(key)) } func (db *MemDB) DeleteSync(key []byte) { db.mtx.Lock() defer db.mtx.Unlock() - + panicNilKey(key) delete(db.db, string(key)) } // NOTE: Implements atomicSetDeleter func (db *MemDB) DeleteNoLock(key []byte) { + panicNilKey(key) delete(db.db, string(key)) } diff --git a/db/types.go b/db/types.go index a6edbdd85..54c1025a0 100644 --- a/db/types.go +++ b/db/types.go @@ -121,3 +121,10 @@ type Iterator interface { func bz(s string) []byte { return []byte(s) } + +// All DB funcs should panic on nil key. +func panicNilKey(key []byte) { + if key == nil { + panic("nil key") + } +}