diff --git a/db/fsdb.go b/db/fsdb.go index 116dc3eef..ac9cdd741 100644 --- a/db/fsdb.go +++ b/db/fsdb.go @@ -7,6 +7,7 @@ import ( "os" "path" "path/filepath" + "sort" "sync" "github.com/pkg/errors" @@ -24,6 +25,8 @@ func init() { }, false) } +var _ DB = (*FSDB)(nil) + // It's slow. type FSDB struct { mtx sync.Mutex @@ -160,26 +163,20 @@ func (db *FSDB) Mutex() *sync.Mutex { } func (db *FSDB) Iterator(start, end []byte) Iterator { - /* - XXX - it := newMemDBIterator() - it.db = db - it.cur = 0 - - db.mtx.Lock() - defer db.mtx.Unlock() - - // We need a copy of all of the keys. - // Not the best, but probably not a bottleneck depending. - keys, err := list(db.dir) - if err != nil { - panic(errors.Wrap(err, fmt.Sprintf("Listing keys in %s", db.dir))) - } - sort.Strings(keys) - it.keys = keys - return it - */ - return nil + it := newMemDBIterator(db, start, end) + + db.mtx.Lock() + defer db.mtx.Unlock() + + // We need a copy of all of the keys. + // Not the best, but probably not a bottleneck depending. + keys, err := list(db.dir, start, end) + if err != nil { + panic(errors.Wrap(err, fmt.Sprintf("Listing keys in %s", db.dir))) + } + sort.Strings(keys) + it.keys = keys + return it } func (db *FSDB) ReverseIterator(start, end []byte) Iterator { @@ -233,7 +230,7 @@ func remove(path string) error { // List files of a path. // Paths will NOT include dir as the prefix. // CONTRACT: returns os errors directly without wrapping. -func list(dirPath string) (paths []string, err error) { +func list(dirPath string, start, end []byte) ([]string, error) { dir, err := os.Open(dirPath) if err != nil { return nil, err @@ -244,12 +241,15 @@ func list(dirPath string) (paths []string, err error) { if err != nil { return nil, err } - for i, name := range names { + var paths []string + for _, name := range names { n, err := url.PathUnescape(name) if err != nil { return nil, fmt.Errorf("Failed to unescape %s while listing", name) } - names[i] = n + if checkKeyCondition(n, start, end) { + paths = append(paths, n) + } } - return names, nil + return paths, nil } diff --git a/db/mem_db.go b/db/mem_db.go index 84d14de98..a9f21d526 100644 --- a/db/mem_db.go +++ b/db/mem_db.go @@ -1,10 +1,8 @@ package db import ( - "bytes" "fmt" "sort" - "strings" "sync" ) @@ -159,9 +157,7 @@ func (db *MemDB) ReverseIterator(start, end []byte) Iterator { func (db *MemDB) getSortedKeys(start, end []byte) []string { keys := []string{} for key, _ := range db.db { - leftCondition := bytes.Equal(start, BeginningKey()) || strings.Compare(key, string(start)) >= 0 - rightCondition := bytes.Equal(end, EndingKey()) || strings.Compare(key, string(end)) < 0 - if leftCondition && rightCondition { + if checkKeyCondition(key, start, end) { keys = append(keys, key) } } diff --git a/db/util.go b/db/util.go index 89c777622..02f4a52f0 100644 --- a/db/util.go +++ b/db/util.go @@ -1,5 +1,10 @@ package db +import ( + "bytes" + "strings" +) + func IteratePrefix(db DB, prefix []byte) Iterator { var start, end []byte if len(prefix) == 0 { @@ -33,3 +38,9 @@ func cpIncr(bz []byte) (ret []byte) { } return EndingKey() } + +func checkKeyCondition(key string, start, end []byte) bool { + leftCondition := bytes.Equal(start, BeginningKey()) || strings.Compare(key, string(start)) >= 0 + rightCondition := bytes.Equal(end, EndingKey()) || strings.Compare(key, string(end)) < 0 + return leftCondition && rightCondition +}