@ -1,4 +1,4 @@ | |||||
*.swp | |||||
*.sw[opqr] | |||||
vendor | vendor | ||||
.glide | .glide | ||||
@ -1,57 +1,117 @@ | |||||
.PHONY: all test get_vendor_deps ensure_tools | |||||
GOTOOLS = \ | GOTOOLS = \ | ||||
github.com/Masterminds/glide \ | github.com/Masterminds/glide \ | ||||
github.com/alecthomas/gometalinter | |||||
github.com/gogo/protobuf/protoc-gen-gogo \ | |||||
github.com/gogo/protobuf/gogoproto | |||||
# github.com/alecthomas/gometalinter.v2 \ | |||||
PACKAGES=$(shell go list ./... | grep -v '/vendor/') | |||||
REPO:=github.com/tendermint/tmlibs | |||||
GOTOOLS_CHECK = glide gometalinter.v2 protoc protoc-gen-gogo | |||||
INCLUDE = -I=. -I=${GOPATH}/src -I=${GOPATH}/src/github.com/gogo/protobuf/protobuf | |||||
all: test | |||||
all: check get_vendor_deps protoc build test install metalinter | |||||
test: | |||||
@echo "--> Running linter" | |||||
@make metalinter_test | |||||
@echo "--> Running go test" | |||||
@go test $(PACKAGES) | |||||
check: check_tools | |||||
######################################## | |||||
### Build | |||||
protoc: | |||||
## If you get the following error, | |||||
## "error while loading shared libraries: libprotobuf.so.14: cannot open shared object file: No such file or directory" | |||||
## See https://stackoverflow.com/a/25518702 | |||||
protoc $(INCLUDE) --gogo_out=plugins=grpc:. common/*.proto | |||||
@echo "--> adding nolint declarations to protobuf generated files" | |||||
@awk '/package common/ { print "//nolint: gas"; print; next }1' common/types.pb.go > common/types.pb.go.new | |||||
@mv common/types.pb.go.new common/types.pb.go | |||||
build: | |||||
# Nothing to build! | |||||
install: | |||||
# Nothing to install! | |||||
######################################## | |||||
### Tools & dependencies | |||||
check_tools: | |||||
@# https://stackoverflow.com/a/25668869 | |||||
@echo "Found tools: $(foreach tool,$(GOTOOLS_CHECK),\ | |||||
$(if $(shell which $(tool)),$(tool),$(error "No $(tool) in PATH")))" | |||||
get_tools: | |||||
@echo "--> Installing tools" | |||||
go get -u -v $(GOTOOLS) | |||||
# @gometalinter.v2 --install | |||||
get_vendor_deps: ensure_tools | |||||
get_protoc: | |||||
@# https://github.com/google/protobuf/releases | |||||
curl -L https://github.com/google/protobuf/releases/download/v3.4.1/protobuf-cpp-3.4.1.tar.gz | tar xvz && \ | |||||
cd protobuf-3.4.1 && \ | |||||
DIST_LANG=cpp ./configure && \ | |||||
make && \ | |||||
make install && \ | |||||
cd .. && \ | |||||
rm -rf protobuf-3.4.1 | |||||
update_tools: | |||||
@echo "--> Updating tools" | |||||
@go get -u $(GOTOOLS) | |||||
get_vendor_deps: | |||||
@rm -rf vendor/ | @rm -rf vendor/ | ||||
@echo "--> Running glide install" | @echo "--> Running glide install" | ||||
@glide install | @glide install | ||||
ensure_tools: | |||||
go get $(GOTOOLS) | |||||
@gometalinter --install | |||||
######################################## | |||||
### Testing | |||||
test: | |||||
go test -tags gcc `glide novendor` | |||||
metalinter: | |||||
gometalinter --vendor --deadline=600s --enable-all --disable=lll ./... | |||||
######################################## | |||||
### Formatting, linting, and vetting | |||||
metalinter_test: | |||||
gometalinter --vendor --deadline=600s --disable-all \ | |||||
fmt: | |||||
@go fmt ./... | |||||
metalinter: | |||||
@echo "==> Running linter" | |||||
gometalinter.v2 --vendor --deadline=600s --disable-all \ | |||||
--enable=deadcode \ | --enable=deadcode \ | ||||
--enable=goconst \ | --enable=goconst \ | ||||
--enable=goimports \ | |||||
--enable=gosimple \ | --enable=gosimple \ | ||||
--enable=ineffassign \ | |||||
--enable=interfacer \ | |||||
--enable=ineffassign \ | |||||
--enable=megacheck \ | --enable=megacheck \ | ||||
--enable=misspell \ | |||||
--enable=staticcheck \ | |||||
--enable=misspell \ | |||||
--enable=staticcheck \ | |||||
--enable=safesql \ | --enable=safesql \ | ||||
--enable=structcheck \ | |||||
--enable=unconvert \ | |||||
--enable=structcheck \ | |||||
--enable=unconvert \ | |||||
--enable=unused \ | --enable=unused \ | ||||
--enable=varcheck \ | |||||
--enable=varcheck \ | |||||
--enable=vetshadow \ | --enable=vetshadow \ | ||||
--enable=vet \ | |||||
./... | ./... | ||||
#--enable=maligned \ | |||||
#--enable=gas \ | #--enable=gas \ | ||||
#--enable=aligncheck \ | #--enable=aligncheck \ | ||||
#--enable=dupl \ | #--enable=dupl \ | ||||
#--enable=errcheck \ | #--enable=errcheck \ | ||||
#--enable=gocyclo \ | #--enable=gocyclo \ | ||||
#--enable=goimports \ | |||||
#--enable=golint \ <== comments on anything exported | #--enable=golint \ <== comments on anything exported | ||||
#--enable=gotype \ | #--enable=gotype \ | ||||
#--enable=unparam \ | |||||
#--enable=interfacer \ | |||||
#--enable=unparam \ | |||||
#--enable=vet \ | |||||
metalinter_all: | |||||
protoc $(INCLUDE) --lint_out=. types/*.proto | |||||
gometalinter.v2 --vendor --deadline=600s --enable-all --disable=lll ./... | |||||
# To avoid unintended conflicts with file names, always add to .PHONY | |||||
# unless there is a reason not to. | |||||
# https://www.gnu.org/software/make/manual/html_node/Phony-Targets.html | |||||
.PHONY: check protoc build check_tools get_tools get_protoc update_tools get_vendor_deps test fmt metalinter metalinter_all |
@ -0,0 +1,53 @@ | |||||
package common | |||||
import ( | |||||
"encoding/hex" | |||||
"fmt" | |||||
"strings" | |||||
) | |||||
// The main purpose of HexBytes is to enable HEX-encoding for json/encoding. | |||||
type HexBytes []byte | |||||
// Marshal needed for protobuf compatibility | |||||
func (bz HexBytes) Marshal() ([]byte, error) { | |||||
return bz, nil | |||||
} | |||||
// Unmarshal needed for protobuf compatibility | |||||
func (bz *HexBytes) Unmarshal(data []byte) error { | |||||
*bz = data | |||||
return nil | |||||
} | |||||
// This is the point of Bytes. | |||||
func (bz HexBytes) MarshalJSON() ([]byte, error) { | |||||
s := strings.ToUpper(hex.EncodeToString(bz)) | |||||
jbz := make([]byte, len(s)+2) | |||||
jbz[0] = '"' | |||||
copy(jbz[1:], []byte(s)) | |||||
jbz[len(jbz)-1] = '"' | |||||
return jbz, nil | |||||
} | |||||
// This is the point of Bytes. | |||||
func (bz *HexBytes) UnmarshalJSON(data []byte) error { | |||||
if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' { | |||||
return fmt.Errorf("Invalid hex string: %s", data) | |||||
} | |||||
bz2, err := hex.DecodeString(string(data[1 : len(data)-1])) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
*bz = bz2 | |||||
return nil | |||||
} | |||||
// Allow it to fulfill various interfaces in light-client, etc... | |||||
func (bz HexBytes) Bytes() []byte { | |||||
return bz | |||||
} | |||||
func (bz HexBytes) String() string { | |||||
return strings.ToUpper(hex.EncodeToString(bz)) | |||||
} |
@ -0,0 +1,65 @@ | |||||
package common | |||||
import ( | |||||
"encoding/json" | |||||
"fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
// This is a trivial test for protobuf compatibility. | |||||
func TestMarshal(t *testing.T) { | |||||
bz := []byte("hello world") | |||||
dataB := HexBytes(bz) | |||||
bz2, err := dataB.Marshal() | |||||
assert.Nil(t, err) | |||||
assert.Equal(t, bz, bz2) | |||||
var dataB2 HexBytes | |||||
err = (&dataB2).Unmarshal(bz) | |||||
assert.Nil(t, err) | |||||
assert.Equal(t, dataB, dataB2) | |||||
} | |||||
// Test that the hex encoding works. | |||||
func TestJSONMarshal(t *testing.T) { | |||||
type TestStruct struct { | |||||
B1 []byte | |||||
B2 HexBytes | |||||
} | |||||
cases := []struct { | |||||
input []byte | |||||
expected string | |||||
}{ | |||||
{[]byte(``), `{"B1":"","B2":""}`}, | |||||
{[]byte(`a`), `{"B1":"YQ==","B2":"61"}`}, | |||||
{[]byte(`abc`), `{"B1":"YWJj","B2":"616263"}`}, | |||||
} | |||||
for i, tc := range cases { | |||||
t.Run(fmt.Sprintf("Case %d", i), func(t *testing.T) { | |||||
ts := TestStruct{B1: tc.input, B2: tc.input} | |||||
// Test that it marshals correctly to JSON. | |||||
jsonBytes, err := json.Marshal(ts) | |||||
if err != nil { | |||||
t.Fatal(err) | |||||
} | |||||
assert.Equal(t, string(jsonBytes), tc.expected) | |||||
// TODO do fuzz testing to ensure that unmarshal fails | |||||
// Test that unmarshaling works correctly. | |||||
ts2 := TestStruct{} | |||||
err = json.Unmarshal(jsonBytes, &ts2) | |||||
if err != nil { | |||||
t.Fatal(err) | |||||
} | |||||
assert.Equal(t, ts2.B1, tc.input) | |||||
assert.Equal(t, ts2.B2, HexBytes(tc.input)) | |||||
}) | |||||
} | |||||
} |
@ -1,153 +0,0 @@ | |||||
package common | |||||
import ( | |||||
"encoding/json" | |||||
"io" | |||||
"net/http" | |||||
"gopkg.in/go-playground/validator.v9" | |||||
"github.com/pkg/errors" | |||||
) | |||||
type ErrorResponse struct { | |||||
Success bool `json:"success,omitempty"` | |||||
// Err is the error message if Success is false | |||||
Err string `json:"error,omitempty"` | |||||
// Code is set if Success is false | |||||
Code int `json:"code,omitempty"` | |||||
} | |||||
// ErrorWithCode makes an ErrorResponse with the | |||||
// provided err's Error() content, and status code. | |||||
// It panics if err is nil. | |||||
func ErrorWithCode(err error, code int) *ErrorResponse { | |||||
return &ErrorResponse{ | |||||
Err: err.Error(), | |||||
Code: code, | |||||
} | |||||
} | |||||
// Ensure that ErrorResponse implements error | |||||
var _ error = (*ErrorResponse)(nil) | |||||
func (er *ErrorResponse) Error() string { | |||||
return er.Err | |||||
} | |||||
// Ensure that ErrorResponse implements httpCoder | |||||
var _ httpCoder = (*ErrorResponse)(nil) | |||||
func (er *ErrorResponse) HTTPCode() int { | |||||
return er.Code | |||||
} | |||||
var errNilBody = errors.Errorf("expecting a non-nil body") | |||||
// FparseJSON unmarshals into save, the body of the provided reader. | |||||
// Since it uses json.Unmarshal, save must be of a pointer type | |||||
// or compatible with json.Unmarshal. | |||||
func FparseJSON(r io.Reader, save interface{}) error { | |||||
if r == nil { | |||||
return errors.Wrap(errNilBody, "Reader") | |||||
} | |||||
dec := json.NewDecoder(r) | |||||
if err := dec.Decode(save); err != nil { | |||||
return errors.Wrap(err, "Decode/Unmarshal") | |||||
} | |||||
return nil | |||||
} | |||||
// ParseRequestJSON unmarshals into save, the body of the | |||||
// request. It closes the body of the request after parsing. | |||||
// Since it uses json.Unmarshal, save must be of a pointer type | |||||
// or compatible with json.Unmarshal. | |||||
func ParseRequestJSON(r *http.Request, save interface{}) error { | |||||
if r == nil || r.Body == nil { | |||||
return errNilBody | |||||
} | |||||
defer r.Body.Close() | |||||
return FparseJSON(r.Body, save) | |||||
} | |||||
// ParseRequestAndValidateJSON unmarshals into save, the body of the | |||||
// request and invokes a validator on the saved content. To ensure | |||||
// validation, make sure to set tags "validate" on your struct as | |||||
// per https://godoc.org/gopkg.in/go-playground/validator.v9. | |||||
// It closes the body of the request after parsing. | |||||
// Since it uses json.Unmarshal, save must be of a pointer type | |||||
// or compatible with json.Unmarshal. | |||||
func ParseRequestAndValidateJSON(r *http.Request, save interface{}) error { | |||||
if r == nil || r.Body == nil { | |||||
return errNilBody | |||||
} | |||||
defer r.Body.Close() | |||||
return FparseAndValidateJSON(r.Body, save) | |||||
} | |||||
// FparseAndValidateJSON like FparseJSON unmarshals into save, | |||||
// the body of the provided reader. However, it invokes the validator | |||||
// to check the set validators on your struct fields as per | |||||
// per https://godoc.org/gopkg.in/go-playground/validator.v9. | |||||
// Since it uses json.Unmarshal, save must be of a pointer type | |||||
// or compatible with json.Unmarshal. | |||||
func FparseAndValidateJSON(r io.Reader, save interface{}) error { | |||||
if err := FparseJSON(r, save); err != nil { | |||||
return err | |||||
} | |||||
return validate(save) | |||||
} | |||||
var theValidator = validator.New() | |||||
func validate(obj interface{}) error { | |||||
return errors.Wrap(theValidator.Struct(obj), "Validate") | |||||
} | |||||
// WriteSuccess JSON marshals the content provided, to an HTTP | |||||
// response, setting the provided status code and setting header | |||||
// "Content-Type" to "application/json". | |||||
func WriteSuccess(w http.ResponseWriter, data interface{}) { | |||||
WriteCode(w, data, 200) | |||||
} | |||||
// WriteCode JSON marshals content, to an HTTP response, | |||||
// setting the provided status code, and setting header | |||||
// "Content-Type" to "application/json". If JSON marshalling fails | |||||
// with an error, WriteCode instead writes out the error invoking | |||||
// WriteError. | |||||
func WriteCode(w http.ResponseWriter, out interface{}, code int) { | |||||
blob, err := json.MarshalIndent(out, "", " ") | |||||
if err != nil { | |||||
WriteError(w, err) | |||||
} else { | |||||
w.Header().Set("Content-Type", "application/json") | |||||
w.WriteHeader(code) | |||||
w.Write(blob) | |||||
} | |||||
} | |||||
type httpCoder interface { | |||||
HTTPCode() int | |||||
} | |||||
// WriteError is a convenience function to write out an | |||||
// error to an http.ResponseWriter, to send out an error | |||||
// that's structured as JSON i.e the form | |||||
// {"error": sss, "code": ddd} | |||||
// If err implements the interface HTTPCode() int, | |||||
// it will use that status code otherwise, it will | |||||
// set code to be http.StatusBadRequest | |||||
func WriteError(w http.ResponseWriter, err error) { | |||||
code := http.StatusBadRequest | |||||
if httpC, ok := err.(httpCoder); ok { | |||||
code = httpC.HTTPCode() | |||||
} | |||||
WriteCode(w, ErrorWithCode(err, code), code) | |||||
} |
@ -1,250 +0,0 @@ | |||||
package common_test | |||||
import ( | |||||
"bytes" | |||||
"encoding/json" | |||||
"errors" | |||||
"io" | |||||
"io/ioutil" | |||||
"net/http" | |||||
"net/http/httptest" | |||||
"reflect" | |||||
"strings" | |||||
"sync" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
"github.com/tendermint/tmlibs/common" | |||||
) | |||||
func TestWriteSuccess(t *testing.T) { | |||||
w := httptest.NewRecorder() | |||||
common.WriteSuccess(w, "foo") | |||||
assert.Equal(t, w.Code, 200, "should get a 200") | |||||
} | |||||
var blankErrResponse = new(common.ErrorResponse) | |||||
func TestWriteError(t *testing.T) { | |||||
tests := [...]struct { | |||||
msg string | |||||
code int | |||||
}{ | |||||
0: { | |||||
msg: "this is a message", | |||||
code: 419, | |||||
}, | |||||
} | |||||
for i, tt := range tests { | |||||
w := httptest.NewRecorder() | |||||
msg := tt.msg | |||||
// First check without a defined code, should send back a 400 | |||||
common.WriteError(w, errors.New(msg)) | |||||
assert.Equal(t, w.Code, http.StatusBadRequest, "#%d: should get a 400", i) | |||||
blob, err := ioutil.ReadAll(w.Body) | |||||
if err != nil { | |||||
assert.Fail(t, "expecting a successful ioutil.ReadAll", "#%d", i) | |||||
continue | |||||
} | |||||
recv := new(common.ErrorResponse) | |||||
if err := json.Unmarshal(blob, recv); err != nil { | |||||
assert.Fail(t, "expecting a successful json.Unmarshal", "#%d", i) | |||||
continue | |||||
} | |||||
assert.Equal(t, reflect.DeepEqual(recv, blankErrResponse), false, "expecting a non-blank error response") | |||||
// Now test with an error that's .HTTPCode() int conforming | |||||
// Reset w | |||||
w = httptest.NewRecorder() | |||||
common.WriteError(w, common.ErrorWithCode(errors.New("foo"), tt.code)) | |||||
assert.Equal(t, w.Code, tt.code, "case #%d", i) | |||||
} | |||||
} | |||||
type marshalFailer struct{} | |||||
var errFooFailed = errors.New("foo failed here") | |||||
func (mf *marshalFailer) MarshalJSON() ([]byte, error) { | |||||
return nil, errFooFailed | |||||
} | |||||
func TestWriteCode(t *testing.T) { | |||||
codes := [...]int{ | |||||
0: http.StatusOK, | |||||
1: http.StatusBadRequest, | |||||
2: http.StatusUnauthorized, | |||||
3: http.StatusInternalServerError, | |||||
} | |||||
for i, code := range codes { | |||||
w := httptest.NewRecorder() | |||||
common.WriteCode(w, "foo", code) | |||||
assert.Equal(t, w.Code, code, "#%d", i) | |||||
// Then for the failed JSON marshaling | |||||
w = httptest.NewRecorder() | |||||
common.WriteCode(w, &marshalFailer{}, code) | |||||
wantCode := http.StatusBadRequest | |||||
assert.Equal(t, w.Code, wantCode, "#%d", i) | |||||
assert.True(t, strings.Contains(w.Body.String(), errFooFailed.Error()), | |||||
"#%d: expected %q in the error message", i, errFooFailed) | |||||
} | |||||
} | |||||
type saver struct { | |||||
Foo int `json:"foo" validate:"min=10"` | |||||
Bar string `json:"bar"` | |||||
} | |||||
type rcloser struct { | |||||
closeOnce sync.Once | |||||
body *bytes.Buffer | |||||
closeChan chan bool | |||||
} | |||||
var errAlreadyClosed = errors.New("already closed") | |||||
func (rc *rcloser) Close() error { | |||||
var err = errAlreadyClosed | |||||
rc.closeOnce.Do(func() { | |||||
err = nil | |||||
rc.closeChan <- true | |||||
close(rc.closeChan) | |||||
}) | |||||
return err | |||||
} | |||||
func (rc *rcloser) Read(b []byte) (int, error) { | |||||
return rc.body.Read(b) | |||||
} | |||||
var _ io.ReadCloser = (*rcloser)(nil) | |||||
func makeReq(strBody string) (*http.Request, <-chan bool) { | |||||
closeChan := make(chan bool, 1) | |||||
buf := new(bytes.Buffer) | |||||
buf.Write([]byte(strBody)) | |||||
req := &http.Request{ | |||||
Header: make(http.Header), | |||||
Body: &rcloser{body: buf, closeChan: closeChan}, | |||||
} | |||||
return req, closeChan | |||||
} | |||||
func TestParseRequestJSON(t *testing.T) { | |||||
tests := [...]struct { | |||||
body string | |||||
wantErr bool | |||||
useNil bool | |||||
}{ | |||||
0: {wantErr: true, body: ``}, | |||||
1: {body: `{}`}, | |||||
2: {body: `{"foo": 2}`}, // Not that the validate tags don't matter here since we are just parsing | |||||
3: {body: `{"foo": "abcd"}`, wantErr: true}, | |||||
4: {useNil: true, wantErr: true}, | |||||
} | |||||
for i, tt := range tests { | |||||
req, closeChan := makeReq(tt.body) | |||||
if tt.useNil { | |||||
req.Body = nil | |||||
} | |||||
sav := new(saver) | |||||
err := common.ParseRequestJSON(req, sav) | |||||
if tt.wantErr { | |||||
assert.NotEqual(t, err, nil, "#%d: want non-nil error", i) | |||||
continue | |||||
} | |||||
assert.Equal(t, err, nil, "#%d: want nil error", i) | |||||
wasClosed := <-closeChan | |||||
assert.Equal(t, wasClosed, true, "#%d: should have invoked close", i) | |||||
} | |||||
} | |||||
func TestFparseJSON(t *testing.T) { | |||||
r1 := strings.NewReader(`{"foo": 1}`) | |||||
sav := new(saver) | |||||
require.Equal(t, common.FparseJSON(r1, sav), nil, "expecting successful parsing") | |||||
r2 := strings.NewReader(`{"bar": "blockchain"}`) | |||||
require.Equal(t, common.FparseJSON(r2, sav), nil, "expecting successful parsing") | |||||
require.Equal(t, reflect.DeepEqual(sav, &saver{Foo: 1, Bar: "blockchain"}), true, "should have parsed both") | |||||
// Now with a nil body | |||||
require.NotEqual(t, nil, common.FparseJSON(nil, sav), "expecting a nil error report") | |||||
} | |||||
func TestFparseAndValidateJSON(t *testing.T) { | |||||
r1 := strings.NewReader(`{"foo": 1}`) | |||||
sav := new(saver) | |||||
require.NotEqual(t, common.FparseAndValidateJSON(r1, sav), nil, "expecting validation to fail") | |||||
r1 = strings.NewReader(`{"foo": 100}`) | |||||
require.Equal(t, common.FparseJSON(r1, sav), nil, "expecting successful parsing") | |||||
r2 := strings.NewReader(`{"bar": "blockchain"}`) | |||||
require.Equal(t, common.FparseAndValidateJSON(r2, sav), nil, "expecting successful parsing") | |||||
require.Equal(t, reflect.DeepEqual(sav, &saver{Foo: 100, Bar: "blockchain"}), true, "should have parsed both") | |||||
// Now with a nil body | |||||
require.NotEqual(t, nil, common.FparseJSON(nil, sav), "expecting a nil error report") | |||||
} | |||||
var blankSaver = new(saver) | |||||
func TestParseAndValidateRequestJSON(t *testing.T) { | |||||
tests := [...]struct { | |||||
body string | |||||
wantErr bool | |||||
useNil bool | |||||
}{ | |||||
0: {wantErr: true, body: ``}, | |||||
1: {body: `{}`, wantErr: true}, // Here it should fail since Foo doesn't meet the minimum value | |||||
2: {body: `{"foo": 2}`, wantErr: true}, // Here validation should fail | |||||
3: {body: `{"foo": "abcd"}`, wantErr: true}, | |||||
4: {useNil: true, wantErr: true}, | |||||
5: {body: `{"foo": 100}`}, // Must succeed | |||||
} | |||||
for i, tt := range tests { | |||||
req, closeChan := makeReq(tt.body) | |||||
if tt.useNil { | |||||
req.Body = nil | |||||
} | |||||
sav := new(saver) | |||||
err := common.ParseRequestAndValidateJSON(req, sav) | |||||
if tt.wantErr { | |||||
assert.NotEqual(t, err, nil, "#%d: want non-nil error", i) | |||||
continue | |||||
} | |||||
assert.Equal(t, err, nil, "#%d: want nil error", i) | |||||
assert.False(t, reflect.DeepEqual(blankSaver, sav), "#%d: expecting a set saver", i) | |||||
wasClosed := <-closeChan | |||||
assert.Equal(t, wasClosed, true, "#%d: should have invoked close", i) | |||||
} | |||||
} | |||||
func TestErrorWithCode(t *testing.T) { | |||||
tests := [...]struct { | |||||
code int | |||||
err error | |||||
}{ | |||||
0: {code: 500, err: errors.New("funky")}, | |||||
1: {code: 406, err: errors.New("purist")}, | |||||
} | |||||
for i, tt := range tests { | |||||
errRes := common.ErrorWithCode(tt.err, tt.code) | |||||
assert.Equal(t, errRes.Error(), tt.err.Error(), "#%d: expecting the error values to be equal", i) | |||||
assert.Equal(t, errRes.Code, tt.code, "expecting the same status code", i) | |||||
assert.Equal(t, errRes.HTTPCode(), tt.code, "expecting the same status code", i) | |||||
} | |||||
} |
@ -0,0 +1,67 @@ | |||||
package common | |||||
import ( | |||||
"bytes" | |||||
"sort" | |||||
) | |||||
//---------------------------------------- | |||||
// KVPair | |||||
/* | |||||
Defined in types.proto | |||||
type KVPair struct { | |||||
Key []byte | |||||
Value []byte | |||||
} | |||||
*/ | |||||
type KVPairs []KVPair | |||||
// Sorting | |||||
func (kvs KVPairs) Len() int { return len(kvs) } | |||||
func (kvs KVPairs) Less(i, j int) bool { | |||||
switch bytes.Compare(kvs[i].Key, kvs[j].Key) { | |||||
case -1: | |||||
return true | |||||
case 0: | |||||
return bytes.Compare(kvs[i].Value, kvs[j].Value) < 0 | |||||
case 1: | |||||
return false | |||||
default: | |||||
panic("invalid comparison result") | |||||
} | |||||
} | |||||
func (kvs KVPairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } | |||||
func (kvs KVPairs) Sort() { sort.Sort(kvs) } | |||||
//---------------------------------------- | |||||
// KI64Pair | |||||
/* | |||||
Defined in types.proto | |||||
type KI64Pair struct { | |||||
Key []byte | |||||
Value int64 | |||||
} | |||||
*/ | |||||
type KI64Pairs []KI64Pair | |||||
// Sorting | |||||
func (kvs KI64Pairs) Len() int { return len(kvs) } | |||||
func (kvs KI64Pairs) Less(i, j int) bool { | |||||
switch bytes.Compare(kvs[i].Key, kvs[j].Key) { | |||||
case -1: | |||||
return true | |||||
case 0: | |||||
return kvs[i].Value < kvs[j].Value | |||||
case 1: | |||||
return false | |||||
default: | |||||
panic("invalid comparison result") | |||||
} | |||||
} | |||||
func (kvs KI64Pairs) Swap(i, j int) { kvs[i], kvs[j] = kvs[j], kvs[i] } | |||||
func (kvs KI64Pairs) Sort() { sort.Sort(kvs) } |
@ -0,0 +1,101 @@ | |||||
// Code generated by protoc-gen-gogo. DO NOT EDIT. | |||||
// source: common/types.proto | |||||
/* | |||||
Package common is a generated protocol buffer package. | |||||
It is generated from these files: | |||||
common/types.proto | |||||
It has these top-level messages: | |||||
KVPair | |||||
KI64Pair | |||||
*/ | |||||
//nolint: gas | |||||
package common | |||||
import proto "github.com/gogo/protobuf/proto" | |||||
import fmt "fmt" | |||||
import math "math" | |||||
import _ "github.com/gogo/protobuf/gogoproto" | |||||
// Reference imports to suppress errors if they are not otherwise used. | |||||
var _ = proto.Marshal | |||||
var _ = fmt.Errorf | |||||
var _ = math.Inf | |||||
// This is a compile-time assertion to ensure that this generated file | |||||
// is compatible with the proto package it is being compiled against. | |||||
// A compilation error at this line likely means your copy of the | |||||
// proto package needs to be updated. | |||||
const _ = proto.GoGoProtoPackageIsVersion2 // please upgrade the proto package | |||||
// Define these here for compatibility but use tmlibs/common.KVPair. | |||||
type KVPair struct { | |||||
Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` | |||||
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` | |||||
} | |||||
func (m *KVPair) Reset() { *m = KVPair{} } | |||||
func (m *KVPair) String() string { return proto.CompactTextString(m) } | |||||
func (*KVPair) ProtoMessage() {} | |||||
func (*KVPair) Descriptor() ([]byte, []int) { return fileDescriptorTypes, []int{0} } | |||||
func (m *KVPair) GetKey() []byte { | |||||
if m != nil { | |||||
return m.Key | |||||
} | |||||
return nil | |||||
} | |||||
func (m *KVPair) GetValue() []byte { | |||||
if m != nil { | |||||
return m.Value | |||||
} | |||||
return nil | |||||
} | |||||
// Define these here for compatibility but use tmlibs/common.KI64Pair. | |||||
type KI64Pair struct { | |||||
Key []byte `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` | |||||
Value int64 `protobuf:"varint,2,opt,name=value,proto3" json:"value,omitempty"` | |||||
} | |||||
func (m *KI64Pair) Reset() { *m = KI64Pair{} } | |||||
func (m *KI64Pair) String() string { return proto.CompactTextString(m) } | |||||
func (*KI64Pair) ProtoMessage() {} | |||||
func (*KI64Pair) Descriptor() ([]byte, []int) { return fileDescriptorTypes, []int{1} } | |||||
func (m *KI64Pair) GetKey() []byte { | |||||
if m != nil { | |||||
return m.Key | |||||
} | |||||
return nil | |||||
} | |||||
func (m *KI64Pair) GetValue() int64 { | |||||
if m != nil { | |||||
return m.Value | |||||
} | |||||
return 0 | |||||
} | |||||
func init() { | |||||
proto.RegisterType((*KVPair)(nil), "common.KVPair") | |||||
proto.RegisterType((*KI64Pair)(nil), "common.KI64Pair") | |||||
} | |||||
func init() { proto.RegisterFile("common/types.proto", fileDescriptorTypes) } | |||||
var fileDescriptorTypes = []byte{ | |||||
// 137 bytes of a gzipped FileDescriptorProto | |||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x4a, 0xce, 0xcf, 0xcd, | |||||
0xcd, 0xcf, 0xd3, 0x2f, 0xa9, 0x2c, 0x48, 0x2d, 0xd6, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, | |||||
0x83, 0x88, 0x49, 0xe9, 0xa6, 0x67, 0x96, 0x64, 0x94, 0x26, 0xe9, 0x25, 0xe7, 0xe7, 0xea, 0xa7, | |||||
0xe7, 0xa7, 0xe7, 0xeb, 0x83, 0xa5, 0x93, 0x4a, 0xd3, 0xc0, 0x3c, 0x30, 0x07, 0xcc, 0x82, 0x68, | |||||
0x53, 0x32, 0xe0, 0x62, 0xf3, 0x0e, 0x0b, 0x48, 0xcc, 0x2c, 0x12, 0x12, 0xe0, 0x62, 0xce, 0x4e, | |||||
0xad, 0x94, 0x60, 0x54, 0x60, 0xd4, 0xe0, 0x09, 0x02, 0x31, 0x85, 0x44, 0xb8, 0x58, 0xcb, 0x12, | |||||
0x73, 0x4a, 0x53, 0x25, 0x98, 0xc0, 0x62, 0x10, 0x8e, 0x92, 0x11, 0x17, 0x87, 0xb7, 0xa7, 0x99, | |||||
0x09, 0x31, 0x7a, 0x98, 0xa1, 0x7a, 0x92, 0xd8, 0xc0, 0x96, 0x19, 0x03, 0x02, 0x00, 0x00, 0xff, | |||||
0xff, 0x5c, 0xb8, 0x46, 0xc5, 0xb9, 0x00, 0x00, 0x00, | |||||
} |
@ -0,0 +1,24 @@ | |||||
syntax = "proto3"; | |||||
package common; | |||||
// For more information on gogo.proto, see: | |||||
// https://github.com/gogo/protobuf/blob/master/extensions.md | |||||
// NOTE: Try really hard not to use custom types, | |||||
// it's often complicated, broken, nor not worth it. | |||||
import "github.com/gogo/protobuf/gogoproto/gogo.proto"; | |||||
//---------------------------------------- | |||||
// Abstract types | |||||
// Define these here for compatibility but use tmlibs/common.KVPair. | |||||
message KVPair { | |||||
bytes key = 1; | |||||
bytes value = 2; | |||||
} | |||||
// Define these here for compatibility but use tmlibs/common.KI64Pair. | |||||
message KI64Pair { | |||||
bytes key = 1; | |||||
int64 value = 2; | |||||
} |
@ -0,0 +1,151 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"os" | |||||
"path/filepath" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
) | |||||
func cleanupDBDir(dir, name string) { | |||||
os.RemoveAll(filepath.Join(dir, name) + ".db") | |||||
} | |||||
func testBackendGetSetDelete(t *testing.T, backend string) { | |||||
// Default | |||||
dir, dirname := cmn.Tempdir(fmt.Sprintf("test_backend_%s_", backend)) | |||||
defer dir.Close() | |||||
db := NewDB("testdb", backend, dirname) | |||||
// A nonexistent key should return nil, even if the key is empty | |||||
require.Nil(t, db.Get([]byte(""))) | |||||
// A nonexistent key should return nil, even if the key is nil | |||||
require.Nil(t, db.Get(nil)) | |||||
// A nonexistent key should return nil. | |||||
key := []byte("abc") | |||||
require.Nil(t, db.Get(key)) | |||||
// Set empty value. | |||||
db.Set(key, []byte("")) | |||||
require.NotNil(t, db.Get(key)) | |||||
require.Empty(t, db.Get(key)) | |||||
// Set nil value. | |||||
db.Set(key, nil) | |||||
require.NotNil(t, db.Get(key)) | |||||
require.Empty(t, db.Get(key)) | |||||
// Delete. | |||||
db.Delete(key) | |||||
require.Nil(t, db.Get(key)) | |||||
} | |||||
func TestBackendsGetSetDelete(t *testing.T) { | |||||
for dbType, _ := range backends { | |||||
testBackendGetSetDelete(t, dbType) | |||||
} | |||||
} | |||||
func withDB(t *testing.T, creator dbCreator, fn func(DB)) { | |||||
name := cmn.Fmt("test_%x", cmn.RandStr(12)) | |||||
db, err := creator(name, "") | |||||
defer cleanupDBDir("", name) | |||||
assert.Nil(t, err) | |||||
fn(db) | |||||
db.Close() | |||||
} | |||||
func TestBackendsNilKeys(t *testing.T) { | |||||
// Test all backends. | |||||
for dbType, creator := range backends { | |||||
withDB(t, creator, func(db DB) { | |||||
t.Run(fmt.Sprintf("Testing %s", dbType), func(t *testing.T) { | |||||
// Nil keys are treated as the empty key for most operations. | |||||
expect := func(key, value []byte) { | |||||
if len(key) == 0 { // nil or empty | |||||
assert.Equal(t, db.Get(nil), db.Get([]byte(""))) | |||||
assert.Equal(t, db.Has(nil), db.Has([]byte(""))) | |||||
} | |||||
assert.Equal(t, db.Get(key), value) | |||||
assert.Equal(t, db.Has(key), value != nil) | |||||
} | |||||
// Not set | |||||
expect(nil, nil) | |||||
// Set nil value | |||||
db.Set(nil, nil) | |||||
expect(nil, []byte("")) | |||||
// Set empty value | |||||
db.Set(nil, []byte("")) | |||||
expect(nil, []byte("")) | |||||
// Set nil, Delete nil | |||||
db.Set(nil, []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.Delete(nil) | |||||
expect(nil, nil) | |||||
// Set nil, Delete empty | |||||
db.Set(nil, []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.Delete([]byte("")) | |||||
expect(nil, nil) | |||||
// Set empty, Delete nil | |||||
db.Set([]byte(""), []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.Delete(nil) | |||||
expect(nil, nil) | |||||
// Set empty, Delete empty | |||||
db.Set([]byte(""), []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.Delete([]byte("")) | |||||
expect(nil, nil) | |||||
// SetSync nil, DeleteSync nil | |||||
db.SetSync(nil, []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.DeleteSync(nil) | |||||
expect(nil, nil) | |||||
// SetSync nil, DeleteSync empty | |||||
db.SetSync(nil, []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.DeleteSync([]byte("")) | |||||
expect(nil, nil) | |||||
// SetSync empty, DeleteSync nil | |||||
db.SetSync([]byte(""), []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.DeleteSync(nil) | |||||
expect(nil, nil) | |||||
// SetSync empty, DeleteSync empty | |||||
db.SetSync([]byte(""), []byte("abc")) | |||||
expect(nil, []byte("abc")) | |||||
db.DeleteSync([]byte("")) | |||||
expect(nil, nil) | |||||
}) | |||||
}) | |||||
} | |||||
} | |||||
func TestGoLevelDBBackendStr(t *testing.T) { | |||||
name := cmn.Fmt("test_%x", cmn.RandStr(12)) | |||||
db := NewDB(name, GoLevelDBBackendStr, "") | |||||
defer cleanupDBDir("", name) | |||||
_, ok := db.(*GoLevelDB) | |||||
assert.True(t, ok) | |||||
} |
@ -0,0 +1,155 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
) | |||||
func checkValid(t *testing.T, itr Iterator, expected bool) { | |||||
valid := itr.Valid() | |||||
require.Equal(t, expected, valid) | |||||
} | |||||
func checkNext(t *testing.T, itr Iterator, expected bool) { | |||||
itr.Next() | |||||
valid := itr.Valid() | |||||
require.Equal(t, expected, valid) | |||||
} | |||||
func checkNextPanics(t *testing.T, itr Iterator) { | |||||
assert.Panics(t, func() { itr.Next() }, "checkNextPanics expected panic but didn't") | |||||
} | |||||
func checkItem(t *testing.T, itr Iterator, key []byte, value []byte) { | |||||
k, v := itr.Key(), itr.Value() | |||||
assert.Exactly(t, key, k) | |||||
assert.Exactly(t, value, v) | |||||
} | |||||
func checkInvalid(t *testing.T, itr Iterator) { | |||||
checkValid(t, itr, false) | |||||
checkKeyPanics(t, itr) | |||||
checkValuePanics(t, itr) | |||||
checkNextPanics(t, itr) | |||||
} | |||||
func checkKeyPanics(t *testing.T, itr Iterator) { | |||||
assert.Panics(t, func() { itr.Key() }, "checkKeyPanics expected panic but didn't") | |||||
} | |||||
func checkValuePanics(t *testing.T, itr Iterator) { | |||||
assert.Panics(t, func() { itr.Key() }, "checkValuePanics expected panic but didn't") | |||||
} | |||||
func newTempDB(t *testing.T, backend string) (db DB) { | |||||
dir, dirname := cmn.Tempdir("test_go_iterator") | |||||
db = NewDB("testdb", backend, dirname) | |||||
dir.Close() | |||||
return db | |||||
} | |||||
func TestDBIteratorSingleKey(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
itr := db.Iterator(nil, nil) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, false) | |||||
checkValid(t, itr, false) | |||||
checkNextPanics(t, itr) | |||||
// Once invalid... | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorTwoKeys(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
db.SetSync(bz("2"), bz("value_1")) | |||||
{ // Fail by calling Next too much | |||||
itr := db.Iterator(nil, nil) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, true) | |||||
checkValid(t, itr, true) | |||||
checkNext(t, itr, false) | |||||
checkValid(t, itr, false) | |||||
checkNextPanics(t, itr) | |||||
// Once invalid... | |||||
checkInvalid(t, itr) | |||||
} | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorMany(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
keys := make([][]byte, 100) | |||||
for i := 0; i < 100; i++ { | |||||
keys[i] = []byte{byte(i)} | |||||
} | |||||
value := []byte{5} | |||||
for _, k := range keys { | |||||
db.Set(k, value) | |||||
} | |||||
itr := db.Iterator(nil, nil) | |||||
defer itr.Close() | |||||
for ; itr.Valid(); itr.Next() { | |||||
assert.Equal(t, db.Get(itr.Key()), itr.Value()) | |||||
} | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorEmpty(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := db.Iterator(nil, nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorEmptyBeginAfter(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := db.Iterator(bz("1"), nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
func TestDBIteratorNonemptyBeginAfter(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
itr := db.Iterator(bz("2"), nil) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} |
@ -0,0 +1,254 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"io/ioutil" | |||||
"net/url" | |||||
"os" | |||||
"path/filepath" | |||||
"sort" | |||||
"sync" | |||||
"github.com/pkg/errors" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
) | |||||
const ( | |||||
keyPerm = os.FileMode(0600) | |||||
dirPerm = os.FileMode(0700) | |||||
) | |||||
func init() { | |||||
registerDBCreator(FSDBBackendStr, func(name string, dir string) (DB, error) { | |||||
dbPath := filepath.Join(dir, name+".db") | |||||
return NewFSDB(dbPath), nil | |||||
}, false) | |||||
} | |||||
var _ DB = (*FSDB)(nil) | |||||
// It's slow. | |||||
type FSDB struct { | |||||
mtx sync.Mutex | |||||
dir string | |||||
} | |||||
func NewFSDB(dir string) *FSDB { | |||||
err := os.MkdirAll(dir, dirPerm) | |||||
if err != nil { | |||||
panic(errors.Wrap(err, "Creating FSDB dir "+dir)) | |||||
} | |||||
database := &FSDB{ | |||||
dir: dir, | |||||
} | |||||
return database | |||||
} | |||||
func (db *FSDB) Get(key []byte) []byte { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
key = escapeKey(key) | |||||
path := db.nameToPath(key) | |||||
value, err := read(path) | |||||
if os.IsNotExist(err) { | |||||
return nil | |||||
} else if err != nil { | |||||
panic(errors.Wrapf(err, "Getting key %s (0x%X)", string(key), key)) | |||||
} | |||||
return value | |||||
} | |||||
func (db *FSDB) Has(key []byte) bool { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
key = escapeKey(key) | |||||
path := db.nameToPath(key) | |||||
return cmn.FileExists(path) | |||||
} | |||||
func (db *FSDB) Set(key []byte, value []byte) { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
db.SetNoLock(key, value) | |||||
} | |||||
func (db *FSDB) SetSync(key []byte, value []byte) { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
db.SetNoLock(key, value) | |||||
} | |||||
// NOTE: Implements atomicSetDeleter. | |||||
func (db *FSDB) SetNoLock(key []byte, value []byte) { | |||||
key = escapeKey(key) | |||||
value = nonNilBytes(value) | |||||
path := db.nameToPath(key) | |||||
err := write(path, value) | |||||
if err != nil { | |||||
panic(errors.Wrapf(err, "Setting key %s (0x%X)", string(key), key)) | |||||
} | |||||
} | |||||
func (db *FSDB) Delete(key []byte) { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
db.DeleteNoLock(key) | |||||
} | |||||
func (db *FSDB) DeleteSync(key []byte) { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
db.DeleteNoLock(key) | |||||
} | |||||
// NOTE: Implements atomicSetDeleter. | |||||
func (db *FSDB) DeleteNoLock(key []byte) { | |||||
key = escapeKey(key) | |||||
path := db.nameToPath(key) | |||||
err := remove(path) | |||||
if os.IsNotExist(err) { | |||||
return | |||||
} else if err != nil { | |||||
panic(errors.Wrapf(err, "Removing key %s (0x%X)", string(key), key)) | |||||
} | |||||
} | |||||
func (db *FSDB) Close() { | |||||
// Nothing to do. | |||||
} | |||||
func (db *FSDB) Print() { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
panic("FSDB.Print not yet implemented") | |||||
} | |||||
func (db *FSDB) Stats() map[string]string { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
panic("FSDB.Stats not yet implemented") | |||||
} | |||||
func (db *FSDB) NewBatch() Batch { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
// Not sure we would ever want to try... | |||||
// It doesn't seem easy for general filesystems. | |||||
panic("FSDB.NewBatch not yet implemented") | |||||
} | |||||
func (db *FSDB) Mutex() *sync.Mutex { | |||||
return &(db.mtx) | |||||
} | |||||
func (db *FSDB) Iterator(start, end []byte) Iterator { | |||||
db.mtx.Lock() | |||||
defer db.mtx.Unlock() | |||||
// We need a copy of all of the keys. | |||||
// Not the best, but probably not a bottleneck depending. | |||||
keys, err := list(db.dir, start, end) | |||||
if err != nil { | |||||
panic(errors.Wrapf(err, "Listing keys in %s", db.dir)) | |||||
} | |||||
sort.Strings(keys) | |||||
return newMemDBIterator(db, keys, start, end) | |||||
} | |||||
func (db *FSDB) ReverseIterator(start, end []byte) Iterator { | |||||
panic("not implemented yet") // XXX | |||||
} | |||||
func (db *FSDB) nameToPath(name []byte) string { | |||||
n := url.PathEscape(string(name)) | |||||
return filepath.Join(db.dir, n) | |||||
} | |||||
// Read some bytes to a file. | |||||
// CONTRACT: returns os errors directly without wrapping. | |||||
func read(path string) ([]byte, error) { | |||||
f, err := os.Open(path) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
defer f.Close() | |||||
d, err := ioutil.ReadAll(f) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return d, nil | |||||
} | |||||
// Write some bytes from a file. | |||||
// CONTRACT: returns os errors directly without wrapping. | |||||
func write(path string, d []byte) error { | |||||
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, keyPerm) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
defer f.Close() | |||||
_, err = f.Write(d) | |||||
if err != nil { | |||||
return err | |||||
} | |||||
err = f.Sync() | |||||
return err | |||||
} | |||||
// Remove a file. | |||||
// CONTRACT: returns os errors directly without wrapping. | |||||
func remove(path string) error { | |||||
return os.Remove(path) | |||||
} | |||||
// List keys in a directory, stripping of escape sequences and dir portions. | |||||
// CONTRACT: returns os errors directly without wrapping. | |||||
func list(dirPath string, start, end []byte) ([]string, error) { | |||||
dir, err := os.Open(dirPath) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
defer dir.Close() | |||||
names, err := dir.Readdirnames(0) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
var keys []string | |||||
for _, name := range names { | |||||
n, err := url.PathUnescape(name) | |||||
if err != nil { | |||||
return nil, fmt.Errorf("Failed to unescape %s while listing", name) | |||||
} | |||||
key := unescapeKey([]byte(n)) | |||||
if IsKeyInDomain(key, start, end, false) { | |||||
keys = append(keys, string(key)) | |||||
} | |||||
} | |||||
return keys, nil | |||||
} | |||||
// To support empty or nil keys, while the file system doesn't allow empty | |||||
// filenames. | |||||
func escapeKey(key []byte) []byte { | |||||
return []byte("k_" + string(key)) | |||||
} | |||||
func unescapeKey(escKey []byte) []byte { | |||||
if len(escKey) < 2 { | |||||
panic(fmt.Sprintf("Invalid esc key: %x", escKey)) | |||||
} | |||||
if string(escKey[:2]) != "k_" { | |||||
panic(fmt.Sprintf("Invalid esc key: %x", escKey)) | |||||
} | |||||
return escKey[2:] | |||||
} |
@ -0,0 +1,50 @@ | |||||
package db | |||||
import "sync" | |||||
type atomicSetDeleter interface { | |||||
Mutex() *sync.Mutex | |||||
SetNoLock(key, value []byte) | |||||
DeleteNoLock(key []byte) | |||||
} | |||||
type memBatch struct { | |||||
db atomicSetDeleter | |||||
ops []operation | |||||
} | |||||
type opType int | |||||
const ( | |||||
opTypeSet opType = 1 | |||||
opTypeDelete opType = 2 | |||||
) | |||||
type operation struct { | |||||
opType | |||||
key []byte | |||||
value []byte | |||||
} | |||||
func (mBatch *memBatch) Set(key, value []byte) { | |||||
mBatch.ops = append(mBatch.ops, operation{opTypeSet, key, value}) | |||||
} | |||||
func (mBatch *memBatch) Delete(key []byte) { | |||||
mBatch.ops = append(mBatch.ops, operation{opTypeDelete, key, nil}) | |||||
} | |||||
func (mBatch *memBatch) Write() { | |||||
mtx := mBatch.db.Mutex() | |||||
mtx.Lock() | |||||
defer mtx.Unlock() | |||||
for _, op := range mBatch.ops { | |||||
switch op.opType { | |||||
case opTypeSet: | |||||
mBatch.db.SetNoLock(op.key, op.value) | |||||
case opTypeDelete: | |||||
mBatch.db.DeleteNoLock(op.key) | |||||
} | |||||
} | |||||
} |
@ -1,48 +0,0 @@ | |||||
package db | |||||
import ( | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
"github.com/stretchr/testify/require" | |||||
) | |||||
func TestMemDbIterator(t *testing.T) { | |||||
db := NewMemDB() | |||||
keys := make([][]byte, 100) | |||||
for i := 0; i < 100; i++ { | |||||
keys[i] = []byte{byte(i)} | |||||
} | |||||
value := []byte{5} | |||||
for _, k := range keys { | |||||
db.Set(k, value) | |||||
} | |||||
iter := db.Iterator() | |||||
i := 0 | |||||
for iter.Next() { | |||||
assert.Equal(t, db.Get(iter.Key()), iter.Value(), "values dont match for key") | |||||
i += 1 | |||||
} | |||||
assert.Equal(t, i, len(db.db), "iterator didnt cover whole db") | |||||
} | |||||
func TestMemDBClose(t *testing.T) { | |||||
db := NewMemDB() | |||||
copyDB := func(orig map[string][]byte) map[string][]byte { | |||||
copy := make(map[string][]byte) | |||||
for k, v := range orig { | |||||
copy[k] = v | |||||
} | |||||
return copy | |||||
} | |||||
k, v := []byte("foo"), []byte("bar") | |||||
db.Set(k, v) | |||||
require.Equal(t, db.Get(k), v, "expecting a successful get") | |||||
copyBefore := copyDB(db.db) | |||||
db.Close() | |||||
require.Equal(t, db.Get(k), v, "Close is a noop, expecting a successful get") | |||||
copyAfter := copyDB(db.db) | |||||
require.Equal(t, copyBefore, copyAfter, "Close is a noop and shouldn't modify any internal data") | |||||
} |
@ -0,0 +1,133 @@ | |||||
package db | |||||
type DB interface { | |||||
// Get returns nil iff key doesn't exist. | |||||
// A nil key is interpreted as an empty byteslice. | |||||
// CONTRACT: key, value readonly []byte | |||||
Get([]byte) []byte | |||||
// Has checks if a key exists. | |||||
// A nil key is interpreted as an empty byteslice. | |||||
// CONTRACT: key, value readonly []byte | |||||
Has(key []byte) bool | |||||
// Set sets the key. | |||||
// A nil key is interpreted as an empty byteslice. | |||||
// CONTRACT: key, value readonly []byte | |||||
Set([]byte, []byte) | |||||
SetSync([]byte, []byte) | |||||
// Delete deletes the key. | |||||
// A nil key is interpreted as an empty byteslice. | |||||
// CONTRACT: key readonly []byte | |||||
Delete([]byte) | |||||
DeleteSync([]byte) | |||||
// Iterate over a domain of keys in ascending order. End is exclusive. | |||||
// Start must be less than end, or the Iterator is invalid. | |||||
// A nil start is interpreted as an empty byteslice. | |||||
// If end is nil, iterates up to the last item (inclusive). | |||||
// CONTRACT: No writes may happen within a domain while an iterator exists over it. | |||||
// CONTRACT: start, end readonly []byte | |||||
Iterator(start, end []byte) Iterator | |||||
// Iterate over a domain of keys in descending order. End is exclusive. | |||||
// Start must be greater than end, or the Iterator is invalid. | |||||
// If start is nil, iterates from the last/greatest item (inclusive). | |||||
// If end is nil, iterates up to the first/least item (iclusive). | |||||
// CONTRACT: No writes may happen within a domain while an iterator exists over it. | |||||
// CONTRACT: start, end readonly []byte | |||||
ReverseIterator(start, end []byte) Iterator | |||||
// Closes the connection. | |||||
Close() | |||||
// Creates a batch for atomic updates. | |||||
NewBatch() Batch | |||||
// For debugging | |||||
Print() | |||||
// Stats returns a map of property values for all keys and the size of the cache. | |||||
Stats() map[string]string | |||||
} | |||||
//---------------------------------------- | |||||
// Batch | |||||
type Batch interface { | |||||
SetDeleter | |||||
Write() | |||||
} | |||||
type SetDeleter interface { | |||||
Set(key, value []byte) // CONTRACT: key, value readonly []byte | |||||
Delete(key []byte) // CONTRACT: key readonly []byte | |||||
} | |||||
//---------------------------------------- | |||||
// Iterator | |||||
/* | |||||
Usage: | |||||
var itr Iterator = ... | |||||
defer itr.Close() | |||||
for ; itr.Valid(); itr.Next() { | |||||
k, v := itr.Key(); itr.Value() | |||||
// ... | |||||
} | |||||
*/ | |||||
type Iterator interface { | |||||
// The start & end (exclusive) limits to iterate over. | |||||
// If end < start, then the Iterator goes in reverse order. | |||||
// | |||||
// A domain of ([]byte{12, 13}, []byte{12, 14}) will iterate | |||||
// over anything with the prefix []byte{12, 13}. | |||||
// | |||||
// The smallest key is the empty byte array []byte{} - see BeginningKey(). | |||||
// The largest key is the nil byte array []byte(nil) - see EndingKey(). | |||||
// CONTRACT: start, end readonly []byte | |||||
Domain() (start []byte, end []byte) | |||||
// Valid returns whether the current position is valid. | |||||
// Once invalid, an Iterator is forever invalid. | |||||
Valid() bool | |||||
// Next moves the iterator to the next sequential key in the database, as | |||||
// defined by order of iteration. | |||||
// | |||||
// If Valid returns false, this method will panic. | |||||
Next() | |||||
// Key returns the key of the cursor. | |||||
// If Valid returns false, this method will panic. | |||||
// CONTRACT: key readonly []byte | |||||
Key() (key []byte) | |||||
// Value returns the value of the cursor. | |||||
// If Valid returns false, this method will panic. | |||||
// CONTRACT: value readonly []byte | |||||
Value() (value []byte) | |||||
// Close releases the Iterator. | |||||
Close() | |||||
} | |||||
// For testing convenience. | |||||
func bz(s string) []byte { | |||||
return []byte(s) | |||||
} | |||||
// We defensively turn nil keys or values into []byte{} for | |||||
// most operations. | |||||
func nonNilBytes(bz []byte) []byte { | |||||
if bz == nil { | |||||
return []byte{} | |||||
} else { | |||||
return bz | |||||
} | |||||
} |
@ -0,0 +1,60 @@ | |||||
package db | |||||
import ( | |||||
"bytes" | |||||
) | |||||
func IteratePrefix(db DB, prefix []byte) Iterator { | |||||
var start, end []byte | |||||
if len(prefix) == 0 { | |||||
start = nil | |||||
end = nil | |||||
} else { | |||||
start = cp(prefix) | |||||
end = cpIncr(prefix) | |||||
} | |||||
return db.Iterator(start, end) | |||||
} | |||||
//---------------------------------------- | |||||
func cp(bz []byte) (ret []byte) { | |||||
ret = make([]byte, len(bz)) | |||||
copy(ret, bz) | |||||
return ret | |||||
} | |||||
// CONTRACT: len(bz) > 0 | |||||
func cpIncr(bz []byte) (ret []byte) { | |||||
ret = cp(bz) | |||||
for i := len(bz) - 1; i >= 0; i-- { | |||||
if ret[i] < byte(0xFF) { | |||||
ret[i] += 1 | |||||
return | |||||
} else { | |||||
ret[i] = byte(0x00) | |||||
} | |||||
} | |||||
return nil | |||||
} | |||||
// See DB interface documentation for more information. | |||||
func IsKeyInDomain(key, start, end []byte, isReverse bool) bool { | |||||
if !isReverse { | |||||
if bytes.Compare(key, start) < 0 { | |||||
return false | |||||
} | |||||
if end != nil && bytes.Compare(end, key) <= 0 { | |||||
return false | |||||
} | |||||
return true | |||||
} else { | |||||
if start != nil && bytes.Compare(start, key) < 0 { | |||||
return false | |||||
} | |||||
if end != nil && bytes.Compare(key, end) <= 0 { | |||||
return false | |||||
} | |||||
return true | |||||
} | |||||
} |
@ -0,0 +1,93 @@ | |||||
package db | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
) | |||||
// Empty iterator for empty db. | |||||
func TestPrefixIteratorNoMatchNil(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := IteratePrefix(db, []byte("2")) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
// Empty iterator for db populated after iterator created. | |||||
func TestPrefixIteratorNoMatch1(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
itr := IteratePrefix(db, []byte("2")) | |||||
db.SetSync(bz("1"), bz("value_1")) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
// Empty iterator for prefix starting after db entry. | |||||
func TestPrefixIteratorNoMatch2(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("3"), bz("value_3")) | |||||
itr := IteratePrefix(db, []byte("4")) | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
// Iterator with single val for db with single val, starting from that val. | |||||
func TestPrefixIteratorMatch1(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
db.SetSync(bz("2"), bz("value_2")) | |||||
itr := IteratePrefix(db, bz("2")) | |||||
checkValid(t, itr, true) | |||||
checkItem(t, itr, bz("2"), bz("value_2")) | |||||
checkNext(t, itr, false) | |||||
// Once invalid... | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} | |||||
// Iterator with prefix iterates over everything with same prefix. | |||||
func TestPrefixIteratorMatches1N(t *testing.T) { | |||||
for backend, _ := range backends { | |||||
t.Run(fmt.Sprintf("Prefix w/ backend %s", backend), func(t *testing.T) { | |||||
db := newTempDB(t, backend) | |||||
// prefixed | |||||
db.SetSync(bz("a/1"), bz("value_1")) | |||||
db.SetSync(bz("a/3"), bz("value_3")) | |||||
// not | |||||
db.SetSync(bz("b/3"), bz("value_3")) | |||||
db.SetSync(bz("a-3"), bz("value_3")) | |||||
db.SetSync(bz("a.3"), bz("value_3")) | |||||
db.SetSync(bz("abcdefg"), bz("value_3")) | |||||
itr := IteratePrefix(db, bz("a/")) | |||||
checkValid(t, itr, true) | |||||
checkItem(t, itr, bz("a/1"), bz("value_1")) | |||||
checkNext(t, itr, true) | |||||
checkItem(t, itr, bz("a/3"), bz("value_3")) | |||||
// Bad! | |||||
checkNext(t, itr, false) | |||||
//Once invalid... | |||||
checkInvalid(t, itr) | |||||
}) | |||||
} | |||||
} |
@ -0,0 +1,86 @@ | |||||
package merkle | |||||
import ( | |||||
"github.com/tendermint/go-wire" | |||||
cmn "github.com/tendermint/tmlibs/common" | |||||
"golang.org/x/crypto/ripemd160" | |||||
) | |||||
type SimpleMap struct { | |||||
kvs cmn.KVPairs | |||||
sorted bool | |||||
} | |||||
func NewSimpleMap() *SimpleMap { | |||||
return &SimpleMap{ | |||||
kvs: nil, | |||||
sorted: false, | |||||
} | |||||
} | |||||
func (sm *SimpleMap) Set(key string, value interface{}) { | |||||
sm.sorted = false | |||||
// Is value Hashable? | |||||
var vBytes []byte | |||||
if hashable, ok := value.(Hashable); ok { | |||||
vBytes = hashable.Hash() | |||||
} else { | |||||
vBytes = wire.BinaryBytes(value) | |||||
} | |||||
sm.kvs = append(sm.kvs, cmn.KVPair{ | |||||
Key: []byte(key), | |||||
Value: vBytes, | |||||
}) | |||||
} | |||||
// Merkle root hash of items sorted by key. | |||||
// NOTE: Behavior is undefined when key is duplicate. | |||||
func (sm *SimpleMap) Hash() []byte { | |||||
sm.Sort() | |||||
return hashKVPairs(sm.kvs) | |||||
} | |||||
func (sm *SimpleMap) Sort() { | |||||
if sm.sorted { | |||||
return | |||||
} | |||||
sm.kvs.Sort() | |||||
sm.sorted = true | |||||
} | |||||
// Returns a copy of sorted KVPairs. | |||||
// CONTRACT: The returned slice must not be mutated. | |||||
func (sm *SimpleMap) KVPairs() cmn.KVPairs { | |||||
sm.Sort() | |||||
kvs := make(cmn.KVPairs, len(sm.kvs)) | |||||
copy(kvs, sm.kvs) | |||||
return kvs | |||||
} | |||||
//---------------------------------------- | |||||
// A local extension to KVPair that can be hashed. | |||||
type kvPair cmn.KVPair | |||||
func (kv kvPair) Hash() []byte { | |||||
hasher, n, err := ripemd160.New(), new(int), new(error) | |||||
wire.WriteByteSlice(kv.Key, hasher, n, err) | |||||
if *err != nil { | |||||
panic(*err) | |||||
} | |||||
wire.WriteByteSlice(kv.Value, hasher, n, err) | |||||
if *err != nil { | |||||
panic(*err) | |||||
} | |||||
return hasher.Sum(nil) | |||||
} | |||||
func hashKVPairs(kvs cmn.KVPairs) []byte { | |||||
kvsH := make([]Hashable, 0, len(kvs)) | |||||
for _, kvp := range kvs { | |||||
kvsH = append(kvsH, kvPair(kvp)) | |||||
} | |||||
return SimpleHashFromHashables(kvsH) | |||||
} |
@ -0,0 +1,47 @@ | |||||
package merkle | |||||
import ( | |||||
"fmt" | |||||
"testing" | |||||
"github.com/stretchr/testify/assert" | |||||
) | |||||
func TestSimpleMap(t *testing.T) { | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key1", "value1") | |||||
assert.Equal(t, "3bb53f017d2f5b4f144692aa829a5c245ac2b123", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key1", "value2") | |||||
assert.Equal(t, "14a68db29e3f930ffaafeff5e07c17a439384f39", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key1", "value1") | |||||
db.Set("key2", "value2") | |||||
assert.Equal(t, "275c6367f4be335f9c482b6ef72e49c84e3f8bda", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key2", "value2") // NOTE: out of order | |||||
db.Set("key1", "value1") | |||||
assert.Equal(t, "275c6367f4be335f9c482b6ef72e49c84e3f8bda", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key1", "value1") | |||||
db.Set("key2", "value2") | |||||
db.Set("key3", "value3") | |||||
assert.Equal(t, "48d60701cb4c96916f68a958b3368205ebe3809b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
{ | |||||
db := NewSimpleMap() | |||||
db.Set("key2", "value2") // NOTE: out of order | |||||
db.Set("key1", "value1") | |||||
db.Set("key3", "value3") | |||||
assert.Equal(t, "48d60701cb4c96916f68a958b3368205ebe3809b", fmt.Sprintf("%x", db.Hash()), "Hash didn't match") | |||||
} | |||||
} |
@ -0,0 +1,131 @@ | |||||
package merkle | |||||
import ( | |||||
"bytes" | |||||
"fmt" | |||||
) | |||||
type SimpleProof struct { | |||||
Aunts [][]byte `json:"aunts"` // Hashes from leaf's sibling to a root's child. | |||||
} | |||||
// proofs[0] is the proof for items[0]. | |||||
func SimpleProofsFromHashables(items []Hashable) (rootHash []byte, proofs []*SimpleProof) { | |||||
trails, rootSPN := trailsFromHashables(items) | |||||
rootHash = rootSPN.Hash | |||||
proofs = make([]*SimpleProof, len(items)) | |||||
for i, trail := range trails { | |||||
proofs[i] = &SimpleProof{ | |||||
Aunts: trail.FlattenAunts(), | |||||
} | |||||
} | |||||
return | |||||
} | |||||
// Verify that leafHash is a leaf hash of the simple-merkle-tree | |||||
// which hashes to rootHash. | |||||
func (sp *SimpleProof) Verify(index int, total int, leafHash []byte, rootHash []byte) bool { | |||||
computedHash := computeHashFromAunts(index, total, leafHash, sp.Aunts) | |||||
return computedHash != nil && bytes.Equal(computedHash, rootHash) | |||||
} | |||||
func (sp *SimpleProof) String() string { | |||||
return sp.StringIndented("") | |||||
} | |||||
func (sp *SimpleProof) StringIndented(indent string) string { | |||||
return fmt.Sprintf(`SimpleProof{ | |||||
%s Aunts: %X | |||||
%s}`, | |||||
indent, sp.Aunts, | |||||
indent) | |||||
} | |||||
// Use the leafHash and innerHashes to get the root merkle hash. | |||||
// If the length of the innerHashes slice isn't exactly correct, the result is nil. | |||||
func computeHashFromAunts(index int, total int, leafHash []byte, innerHashes [][]byte) []byte { | |||||
// Recursive impl. | |||||
if index >= total { | |||||
return nil | |||||
} | |||||
switch total { | |||||
case 0: | |||||
panic("Cannot call computeHashFromAunts() with 0 total") | |||||
case 1: | |||||
if len(innerHashes) != 0 { | |||||
return nil | |||||
} | |||||
return leafHash | |||||
default: | |||||
if len(innerHashes) == 0 { | |||||
return nil | |||||
} | |||||
numLeft := (total + 1) / 2 | |||||
if index < numLeft { | |||||
leftHash := computeHashFromAunts(index, numLeft, leafHash, innerHashes[:len(innerHashes)-1]) | |||||
if leftHash == nil { | |||||
return nil | |||||
} | |||||
return SimpleHashFromTwoHashes(leftHash, innerHashes[len(innerHashes)-1]) | |||||
} else { | |||||
rightHash := computeHashFromAunts(index-numLeft, total-numLeft, leafHash, innerHashes[:len(innerHashes)-1]) | |||||
if rightHash == nil { | |||||
return nil | |||||
} | |||||
return SimpleHashFromTwoHashes(innerHashes[len(innerHashes)-1], rightHash) | |||||
} | |||||
} | |||||
} | |||||
// Helper structure to construct merkle proof. | |||||
// The node and the tree is thrown away afterwards. | |||||
// Exactly one of node.Left and node.Right is nil, unless node is the root, in which case both are nil. | |||||
// node.Parent.Hash = hash(node.Hash, node.Right.Hash) or | |||||
// hash(node.Left.Hash, node.Hash), depending on whether node is a left/right child. | |||||
type SimpleProofNode struct { | |||||
Hash []byte | |||||
Parent *SimpleProofNode | |||||
Left *SimpleProofNode // Left sibling (only one of Left,Right is set) | |||||
Right *SimpleProofNode // Right sibling (only one of Left,Right is set) | |||||
} | |||||
// Starting from a leaf SimpleProofNode, FlattenAunts() will return | |||||
// the inner hashes for the item corresponding to the leaf. | |||||
func (spn *SimpleProofNode) FlattenAunts() [][]byte { | |||||
// Nonrecursive impl. | |||||
innerHashes := [][]byte{} | |||||
for spn != nil { | |||||
if spn.Left != nil { | |||||
innerHashes = append(innerHashes, spn.Left.Hash) | |||||
} else if spn.Right != nil { | |||||
innerHashes = append(innerHashes, spn.Right.Hash) | |||||
} else { | |||||
break | |||||
} | |||||
spn = spn.Parent | |||||
} | |||||
return innerHashes | |||||
} | |||||
// trails[0].Hash is the leaf hash for items[0]. | |||||
// trails[i].Parent.Parent....Parent == root for all i. | |||||
func trailsFromHashables(items []Hashable) (trails []*SimpleProofNode, root *SimpleProofNode) { | |||||
// Recursive impl. | |||||
switch len(items) { | |||||
case 0: | |||||
return nil, nil | |||||
case 1: | |||||
trail := &SimpleProofNode{items[0].Hash(), nil, nil, nil} | |||||
return []*SimpleProofNode{trail}, trail | |||||
default: | |||||
lefts, leftRoot := trailsFromHashables(items[:(len(items)+1)/2]) | |||||
rights, rightRoot := trailsFromHashables(items[(len(items)+1)/2:]) | |||||
rootHash := SimpleHashFromTwoHashes(leftRoot.Hash, rightRoot.Hash) | |||||
root := &SimpleProofNode{rootHash, nil, nil, nil} | |||||
leftRoot.Parent = root | |||||
leftRoot.Right = rightRoot | |||||
rightRoot.Parent = root | |||||
rightRoot.Left = leftRoot | |||||
return append(lefts, rights...), root | |||||
} | |||||
} |