diff --git a/db/mem_db.go b/db/mem_db.go index ebeb2dded..84d14de98 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -3,6 +3,8 @@ package db import ( "bytes" "fmt" + "sort" + "strings" "sync" ) @@ -12,6 +14,8 @@ func init() { }, false) } +var _ DB = (*MemDB)(nil) + type MemDB struct { mtx sync.Mutex db map[string][]byte @@ -123,49 +127,67 @@ func (db *MemDB) Mutex() *sync.Mutex { //---------------------------------------- func (db *MemDB) Iterator(start, end []byte) Iterator { - /* - XXX - it := newMemDBIterator() - it.db = db - it.cur = 0 - - db.mtx.Lock() - defer db.mtx.Unlock() - - // We need a copy of all of the keys. - // Not the best, but probably not a bottleneck depending. - for key, _ := range db.db { - it.keys = append(it.keys, key) - } - sort.Strings(it.keys) - return it - */ - return nil + it := newMemDBIterator(db, start, end) + + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + it.keys = db.getSortedKeys(start, end) + return it } func (db *MemDB) ReverseIterator(start, end []byte) Iterator { - // XXX + it := newMemDBIterator(db, start, end) + + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + it.keys = db.getSortedKeys(end, start) + // reverse the order + l := len(it.keys) - 1 + for i, v := range it.keys { + it.keys[i] = it.keys[l-i] + it.keys[l-i] = v + } return nil } -type memDBIterator struct { - cur int - keys []string - db DB +func (db *MemDB) getSortedKeys(start, end []byte) []string { + keys := []string{} + for key, _ := range db.db { + leftCondition := bytes.Equal(start, BeginningKey()) || strings.Compare(key, string(start)) >= 0 + rightCondition := bytes.Equal(end, EndingKey()) || strings.Compare(key, string(end)) < 0 + if leftCondition && rightCondition { + keys = append(keys, key) + } + } + sort.Strings(keys) + return keys } -func newMemDBIterator() *memDBIterator { - return &memDBIterator{} +var _ Iterator = (*memDBIterator)(nil) + +type memDBIterator struct { + cur int + keys []string + db DB + start, end []byte } -func (it *memDBIterator) Seek(key []byte) { - for i, ik := range it.keys { - it.cur = i - if bytes.Compare(key, []byte(ik)) <= 0 { - return - } +func newMemDBIterator(db DB, start, end []byte) *memDBIterator { + return &memDBIterator{ + db: db, + start: start, + end: end, } - it.cur += 1 // If not found, becomes invalid. +} + +func (it *memDBIterator) Domain() ([]byte, []byte) { + return it.start, it.end } func (it *memDBIterator) Valid() bool { @@ -208,3 +230,5 @@ func (it *memDBIterator) Close() { func (it *memDBIterator) GetError() error { return nil } + +func (it *memDBIterator) Release() {} diff --git a/db/mem_db_test.go b/db/mem_db_test.go index 42e242857..a08a3679b 100644 --- a/db/mem_db_test.go +++ b/db/mem_db_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestMemDbIterator(t *testing.T) { +func TestMemDBIterator(t *testing.T) { db := NewMemDB() keys := make([][]byte, 100) for i := 0; i < 100; i++ {