diff --git a/binary/binary.go b/binary/binary.go index 3e4f107a4..5bc8357ff 100644 --- a/binary/binary.go +++ b/binary/binary.go @@ -6,11 +6,20 @@ type Binary interface { WriteTo(w io.Writer) (int64, error) } -func WriteTo(b Binary, w io.Writer, n int64, err error) (int64, error) { - if err != nil { - return n, err +func WriteTo(w io.Writer, bz []byte, n *int64, err *error) { + if *err != nil { + return } - var n_ int64 - n_, err = b.WriteTo(w) - return n + n_, err + n_, err_ := w.Write(bz) + *n += int64(n_) + *err = err_ +} + +func ReadFull(r io.Reader, buf []byte, n *int64, err *error) { + if *err != nil { + return + } + n_, err_ := io.ReadFull(r, buf) + *n += int64(n_) + *err = err_ } diff --git a/binary/byteslice.go b/binary/byteslice.go index 7c1c495b1..e3ca14458 100644 --- a/binary/byteslice.go +++ b/binary/byteslice.go @@ -1,71 +1,22 @@ package binary -import "io" -import "bytes" +import ( + "io" +) -type ByteSlice []byte +// ByteSlice -func (self ByteSlice) Equals(other interface{}) bool { - if o, ok := other.(ByteSlice); ok { - return bytes.Equal(self, o) - } else { - return false - } -} - -func (self ByteSlice) Less(other interface{}) bool { - if o, ok := other.(ByteSlice); ok { - return bytes.Compare(self, o) < 0 // -1 if a < b - } else { - panic("Cannot compare unequal types") - } -} - -func (self ByteSlice) ByteSize() int { - return len(self) + 4 -} - -func (self ByteSlice) WriteTo(w io.Writer) (n int64, err error) { - var n_ int - _, err = UInt32(len(self)).WriteTo(w) - if err != nil { - return n, err - } - n_, err = w.Write([]byte(self)) - return int64(n_ + 4), err -} - -func (self ByteSlice) Reader() io.Reader { - return bytes.NewReader([]byte(self)) -} - -func ReadByteSliceSafe(r io.Reader) (bytes ByteSlice, n int64, err error) { - length, n_, err := ReadUInt32Safe(r) - n += n_ - if err != nil { - return nil, n, err - } - bytes = make([]byte, int(length)) - n__, err := io.ReadFull(r, bytes) - n += int64(n__) - if err != nil { - return nil, n, err - } - return bytes, n, nil -} - -func ReadByteSliceN(r io.Reader) (bytes ByteSlice, n int64) { - bytes, n, err := ReadByteSliceSafe(r) - if err != nil { - panic(err) - } - return bytes, n +func WriteByteSlice(w io.Writer, bz []byte, n *int64, err *error) { + WriteUInt32(w, uint32(len(bz)), n, err) + WriteTo(w, bz, n, err) } -func ReadByteSlice(r io.Reader) (bytes ByteSlice) { - bytes, _, err := ReadByteSliceSafe(r) - if err != nil { - panic(err) +func ReadByteSlice(r io.Reader, n *int64, err *error) []byte { + length := ReadUInt32(r, n, err) + if *err != nil { + return nil } - return bytes + buf := make([]byte, int(length)) + ReadFull(r, buf, n, err) + return buf } diff --git a/binary/codec.go b/binary/codec.go index 432ad0bea..d585f1035 100644 --- a/binary/codec.go +++ b/binary/codec.go @@ -1,91 +1,138 @@ package binary import ( + "errors" "io" + "time" ) +type Codec interface { + WriteTo(io.Writer, interface{}, *int64, *error) + ReadFrom(io.Reader, *int64, *error) interface{} +} + +//----------------------------------------------------------------------------- + const ( - TYPE_NIL = Byte(0x00) - TYPE_BYTE = Byte(0x01) - TYPE_INT8 = Byte(0x02) - TYPE_UINT8 = Byte(0x03) - TYPE_INT16 = Byte(0x04) - TYPE_UINT16 = Byte(0x05) - TYPE_INT32 = Byte(0x06) - TYPE_UINT32 = Byte(0x07) - TYPE_INT64 = Byte(0x08) - TYPE_UINT64 = Byte(0x09) - TYPE_STRING = Byte(0x10) - TYPE_BYTESLICE = Byte(0x11) - TYPE_TIME = Byte(0x20) + typeNil = byte(0x00) + typeByte = byte(0x01) + typeInt8 = byte(0x02) + // typeUInt8 = byte(0x03) + typeInt16 = byte(0x04) + typeUInt16 = byte(0x05) + typeInt32 = byte(0x06) + typeUInt32 = byte(0x07) + typeInt64 = byte(0x08) + typeUInt64 = byte(0x09) + typeString = byte(0x10) + typeByteSlice = byte(0x11) + typeTime = byte(0x20) ) -func GetBinaryType(o Binary) Byte { +var BasicCodec = basicCodec{} + +type basicCodec struct{} + +func (bc basicCodec) WriteTo(w io.Writer, o interface{}, n *int64, err *error) { switch o.(type) { case nil: - return TYPE_NIL - case Byte: - return TYPE_BYTE - case Int8: - return TYPE_INT8 - case UInt8: - return TYPE_UINT8 - case Int16: - return TYPE_INT16 - case UInt16: - return TYPE_UINT16 - case Int32: - return TYPE_INT32 - case UInt32: - return TYPE_UINT32 - case Int64: - return TYPE_INT64 - case UInt64: - return TYPE_UINT64 - case String: - return TYPE_STRING - case ByteSlice: - return TYPE_BYTESLICE - case Time: - return TYPE_TIME + WriteByte(w, typeNil, n, err) + case byte: + WriteByte(w, typeByte, n, err) + WriteByte(w, o.(byte), n, err) + case int8: + WriteByte(w, typeInt8, n, err) + WriteInt8(w, o.(int8), n, err) + //case uint8: + // WriteByte(w, typeUInt8, n, err) + // WriteUInt8(w, o.(uint8), n, err) + case int16: + WriteByte(w, typeInt16, n, err) + WriteInt16(w, o.(int16), n, err) + case uint16: + WriteByte(w, typeUInt16, n, err) + WriteUInt16(w, o.(uint16), n, err) + case int32: + WriteByte(w, typeInt32, n, err) + WriteInt32(w, o.(int32), n, err) + case uint32: + WriteByte(w, typeUInt32, n, err) + WriteUInt32(w, o.(uint32), n, err) + case int64: + WriteByte(w, typeInt64, n, err) + WriteInt64(w, o.(int64), n, err) + case uint64: + WriteByte(w, typeUInt64, n, err) + WriteUInt64(w, o.(uint64), n, err) + case string: + WriteByte(w, typeString, n, err) + WriteString(w, o.(string), n, err) + case []byte: + WriteByte(w, typeByteSlice, n, err) + WriteByteSlice(w, o.([]byte), n, err) + case time.Time: + WriteByte(w, typeTime, n, err) + WriteTime(w, o.(time.Time), n, err) default: panic("Unsupported type") } + return } -func ReadBinaryN(r io.Reader) (o Binary, n int64) { - type_, n_ := ReadByteN(r) - n += n_ +func (bc basicCodec) ReadFrom(r io.Reader, n *int64, err *error) interface{} { + type_ := ReadByte(r, n, err) switch type_ { - case TYPE_NIL: - o, n_ = nil, 0 - case TYPE_BYTE: - o, n_ = ReadByteN(r) - case TYPE_INT8: - o, n_ = ReadInt8N(r) - case TYPE_UINT8: - o, n_ = ReadUInt8N(r) - case TYPE_INT16: - o, n_ = ReadInt16N(r) - case TYPE_UINT16: - o, n_ = ReadUInt16N(r) - case TYPE_INT32: - o, n_ = ReadInt32N(r) - case TYPE_UINT32: - o, n_ = ReadUInt32N(r) - case TYPE_INT64: - o, n_ = ReadInt64N(r) - case TYPE_UINT64: - o, n_ = ReadUInt64N(r) - case TYPE_STRING: - o, n_ = ReadStringN(r) - case TYPE_BYTESLICE: - o, n_ = ReadByteSliceN(r) - case TYPE_TIME: - o, n_ = ReadTimeN(r) + case typeNil: + return nil + case typeByte: + return ReadByte(r, n, err) + case typeInt8: + return ReadInt8(r, n, err) + //case typeUInt8: + // return ReadUInt8(r, n, err) + case typeInt16: + return ReadInt16(r, n, err) + case typeUInt16: + return ReadUInt16(r, n, err) + case typeInt32: + return ReadInt32(r, n, err) + case typeUInt32: + return ReadUInt32(r, n, err) + case typeInt64: + return ReadInt64(r, n, err) + case typeUInt64: + return ReadUInt64(r, n, err) + case typeString: + return ReadString(r, n, err) + case typeByteSlice: + return ReadByteSlice(r, n, err) + case typeTime: + return ReadTime(r, n, err) default: panic("Unsupported type") } - n += n_ - return o, n +} + +//----------------------------------------------------------------------------- + +// Creates an adapter codec for Binary things. +// Resulting Codec can be used with merkle/*. +type BinaryCodec struct { + decoder func(io.Reader, *int64, *error) interface{} +} + +func NewBinaryCodec(decoder func(io.Reader, *int64, *error) interface{}) *BinaryCodec { + return &BinaryCodec{decoder} +} + +func (ca *BinaryCodec) WriteTo(w io.Writer, o interface{}, n *int64, err *error) { + if bo, ok := o.(Binary); ok { + WriteTo(w, BinaryBytes(bo), n, err) + } else { + *err = errors.New("BinaryCodec expected Binary object") + } +} + +func (ca *BinaryCodec) ReadFrom(r io.Reader, n *int64, err *error) interface{} { + return ca.decoder(r, n, err) } diff --git a/binary/int.go b/binary/int.go index db92d3fa1..917239ebd 100644 --- a/binary/int.go +++ b/binary/int.go @@ -5,494 +5,118 @@ import ( "io" ) -type Byte byte -type Int8 int8 -type UInt8 uint8 -type Int16 int16 -type UInt16 uint16 -type Int32 int32 -type UInt32 uint32 -type Int64 int64 -type UInt64 uint64 -type Int int -type UInt uint - // Byte -func (self Byte) Equals(other interface{}) bool { - return self == other -} - -func (self Byte) Less(other interface{}) bool { - if o, ok := other.(Byte); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self Byte) ByteSize() int { - return 1 -} - -func (self Byte) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err +func WriteByte(w io.Writer, b byte, n *int64, err *error) { + WriteTo(w, []byte{b}, n, err) } -func ReadByteSafe(r io.Reader) (Byte, int64, error) { - buf := [1]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return 0, int64(n), err - } - return Byte(buf[0]), int64(n), nil -} - -func ReadByteN(r io.Reader) (Byte, int64) { - b, n, err := ReadByteSafe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadByte(r io.Reader) Byte { - b, _, err := ReadByteSafe(r) - if err != nil { - panic(err) - } - return b -} - -func Readbyte(r io.Reader) byte { - return byte(ReadByte(r)) +func ReadByte(r io.Reader, n *int64, err *error) byte { + buf := make([]byte, 1) + ReadFull(r, buf, n, err) + return buf[0] } // Int8 -func (self Int8) Equals(other interface{}) bool { - return self == other -} - -func (self Int8) Less(other interface{}) bool { - if o, ok := other.(Int8); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self Int8) ByteSize() int { - return 1 -} - -func (self Int8) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err -} - -func ReadInt8Safe(r io.Reader) (Int8, int64, error) { - buf := [1]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return Int8(0), int64(n), err - } - return Int8(buf[0]), int64(n), nil +func WriteInt8(w io.Writer, i int8, n *int64, err *error) { + WriteByte(w, byte(i), n, err) } -func ReadInt8N(r io.Reader) (Int8, int64) { - b, n, err := ReadInt8Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadInt8(r io.Reader) Int8 { - b, _, err := ReadInt8Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readint8(r io.Reader) int8 { - return int8(ReadInt8(r)) +func ReadInt8(r io.Reader, n *int64, err *error) int8 { + return int8(ReadByte(r, n, err)) } // UInt8 -func (self UInt8) Equals(other interface{}) bool { - return self == other +func WriteUInt8(w io.Writer, i uint8, n *int64, err *error) { + WriteByte(w, byte(i), n, err) } -func (self UInt8) Less(other interface{}) bool { - if o, ok := other.(UInt8); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self UInt8) ByteSize() int { - return 1 -} - -func (self UInt8) WriteTo(w io.Writer) (int64, error) { - n, err := w.Write([]byte{byte(self)}) - return int64(n), err -} - -func ReadUInt8Safe(r io.Reader) (UInt8, int64, error) { - buf := [1]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return UInt8(0), int64(n), err - } - return UInt8(buf[0]), int64(n), nil -} - -func ReadUInt8N(r io.Reader) (UInt8, int64) { - b, n, err := ReadUInt8Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadUInt8(r io.Reader) UInt8 { - b, _, err := ReadUInt8Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readuint8(r io.Reader) uint8 { - return uint8(ReadUInt8(r)) +func ReadUInt8(r io.Reader, n *int64, err *error) uint8 { + return uint8(ReadByte(r, n, err)) } // Int16 -func (self Int16) Equals(other interface{}) bool { - return self == other -} - -func (self Int16) Less(other interface{}) bool { - if o, ok := other.(Int16); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } +func WriteInt16(w io.Writer, i int16, n *int64, err *error) { + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(i)) + WriteTo(w, buf, n, err) } -func (self Int16) ByteSize() int { - return 2 -} - -func (self Int16) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0} - binary.LittleEndian.PutUint16(buf, uint16(self)) - n, err := w.Write(buf) - return int64(n), err -} - -func ReadInt16Safe(r io.Reader) (Int16, int64, error) { - buf := [2]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return Int16(0), int64(n), err - } - return Int16(binary.LittleEndian.Uint16(buf[:])), int64(n), nil -} - -func ReadInt16N(r io.Reader) (Int16, int64) { - b, n, err := ReadInt16Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadInt16(r io.Reader) Int16 { - b, _, err := ReadInt16Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readint16(r io.Reader) int16 { - return int16(ReadInt16(r)) +func ReadInt16(r io.Reader, n *int64, err *error) int16 { + buf := make([]byte, 2) + ReadFull(r, buf, n, err) + return int16(binary.LittleEndian.Uint16(buf)) } // UInt16 -func (self UInt16) Equals(other interface{}) bool { - return self == other -} - -func (self UInt16) Less(other interface{}) bool { - if o, ok := other.(UInt16); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self UInt16) ByteSize() int { - return 2 -} - -func (self UInt16) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0} - binary.LittleEndian.PutUint16(buf, uint16(self)) - n, err := w.Write(buf) - return int64(n), err +func WriteUInt16(w io.Writer, i uint16, n *int64, err *error) { + buf := make([]byte, 2) + binary.LittleEndian.PutUint16(buf, uint16(i)) + WriteTo(w, buf, n, err) } -func ReadUInt16Safe(r io.Reader) (UInt16, int64, error) { - buf := [2]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return UInt16(0), int64(n), err - } - return UInt16(binary.LittleEndian.Uint16(buf[:])), int64(n), nil -} - -func ReadUInt16N(r io.Reader) (UInt16, int64) { - b, n, err := ReadUInt16Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadUInt16(r io.Reader) UInt16 { - b, _, err := ReadUInt16Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readuint16(r io.Reader) uint16 { - return uint16(ReadUInt16(r)) +func ReadUInt16(r io.Reader, n *int64, err *error) uint16 { + buf := make([]byte, 2) + ReadFull(r, buf, n, err) + return uint16(binary.LittleEndian.Uint16(buf)) } // Int32 -func (self Int32) Equals(other interface{}) bool { - return self == other -} - -func (self Int32) Less(other interface{}) bool { - if o, ok := other.(Int32); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self Int32) ByteSize() int { - return 4 -} - -func (self Int32) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0, 0, 0} - binary.LittleEndian.PutUint32(buf, uint32(self)) - n, err := w.Write(buf) - return int64(n), err -} - -func ReadInt32Safe(r io.Reader) (Int32, int64, error) { - buf := [4]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return Int32(0), int64(n), err - } - return Int32(binary.LittleEndian.Uint32(buf[:])), int64(n), nil +func WriteInt32(w io.Writer, i int32, n *int64, err *error) { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(i)) + WriteTo(w, buf, n, err) } -func ReadInt32N(r io.Reader) (Int32, int64) { - b, n, err := ReadInt32Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadInt32(r io.Reader) Int32 { - b, _, err := ReadInt32Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readint32(r io.Reader) int32 { - return int32(ReadInt32(r)) +func ReadInt32(r io.Reader, n *int64, err *error) int32 { + buf := make([]byte, 4) + ReadFull(r, buf, n, err) + return int32(binary.LittleEndian.Uint32(buf)) } // UInt32 -func (self UInt32) Equals(other interface{}) bool { - return self == other -} - -func (self UInt32) Less(other interface{}) bool { - if o, ok := other.(UInt32); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self UInt32) ByteSize() int { - return 4 -} - -func (self UInt32) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0, 0, 0} - binary.LittleEndian.PutUint32(buf, uint32(self)) - n, err := w.Write(buf) - return int64(n), err -} - -func ReadUInt32Safe(r io.Reader) (UInt32, int64, error) { - buf := [4]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return UInt32(0), int64(n), err - } - return UInt32(binary.LittleEndian.Uint32(buf[:])), int64(n), nil -} - -func ReadUInt32N(r io.Reader) (UInt32, int64) { - b, n, err := ReadUInt32Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadUInt32(r io.Reader) UInt32 { - b, _, err := ReadUInt32Safe(r) - if err != nil { - panic(err) - } - return b +func WriteUInt32(w io.Writer, i uint32, n *int64, err *error) { + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, uint32(i)) + WriteTo(w, buf, n, err) } -func Readuint32(r io.Reader) uint32 { - return uint32(ReadUInt32(r)) +func ReadUInt32(r io.Reader, n *int64, err *error) uint32 { + buf := make([]byte, 4) + ReadFull(r, buf, n, err) + return uint32(binary.LittleEndian.Uint32(buf)) } // Int64 -func (self Int64) Equals(other interface{}) bool { - return self == other -} - -func (self Int64) Less(other interface{}) bool { - if o, ok := other.(Int64); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self Int64) ByteSize() int { - return 8 +func WriteInt64(w io.Writer, i int64, n *int64, err *error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + WriteTo(w, buf, n, err) } -func (self Int64) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0, 0, 0, 0, 0, 0, 0} - binary.LittleEndian.PutUint64(buf, uint64(self)) - n, err := w.Write(buf) - return int64(n), err -} - -func ReadInt64Safe(r io.Reader) (Int64, int64, error) { - buf := [8]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return Int64(0), int64(n), err - } - return Int64(binary.LittleEndian.Uint64(buf[:])), int64(n), nil -} - -func ReadInt64N(r io.Reader) (Int64, int64) { - b, n, err := ReadInt64Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadInt64(r io.Reader) Int64 { - b, _, err := ReadInt64Safe(r) - if err != nil { - panic(err) - } - return b -} - -func Readint64(r io.Reader) int64 { - return int64(ReadInt64(r)) +func ReadInt64(r io.Reader, n *int64, err *error) int64 { + buf := make([]byte, 8) + ReadFull(r, buf, n, err) + return int64(binary.LittleEndian.Uint64(buf)) } // UInt64 -func (self UInt64) Equals(other interface{}) bool { - return self == other -} - -func (self UInt64) Less(other interface{}) bool { - if o, ok := other.(UInt64); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self UInt64) ByteSize() int { - return 8 -} - -func (self UInt64) WriteTo(w io.Writer) (int64, error) { - buf := []byte{0, 0, 0, 0, 0, 0, 0, 0} - binary.LittleEndian.PutUint64(buf, uint64(self)) - n, err := w.Write(buf) - return int64(n), err -} - -func ReadUInt64Safe(r io.Reader) (UInt64, int64, error) { - buf := [8]byte{0} - n, err := io.ReadFull(r, buf[:]) - if err != nil { - return UInt64(0), int64(n), err - } - return UInt64(binary.LittleEndian.Uint64(buf[:])), int64(n), nil -} - -func ReadUInt64N(r io.Reader) (UInt64, int64) { - b, n, err := ReadUInt64Safe(r) - if err != nil { - panic(err) - } - return b, n -} - -func ReadUInt64(r io.Reader) UInt64 { - b, _, err := ReadUInt64Safe(r) - if err != nil { - panic(err) - } - return b +func WriteUInt64(w io.Writer, i uint64, n *int64, err *error) { + buf := make([]byte, 8) + binary.LittleEndian.PutUint64(buf, uint64(i)) + WriteTo(w, buf, n, err) } -func Readuint64(r io.Reader) uint64 { - return uint64(ReadUInt64(r)) +func ReadUInt64(r io.Reader, n *int64, err *error) uint64 { + buf := make([]byte, 8) + ReadFull(r, buf, n, err) + return uint64(binary.LittleEndian.Uint64(buf)) } diff --git a/binary/string.go b/binary/string.go index 678e61a25..3ad25d0d8 100644 --- a/binary/string.go +++ b/binary/string.go @@ -2,67 +2,19 @@ package binary import "io" -type String string - // String -func (self String) Equals(other interface{}) bool { - return self == other -} - -func (self String) Less(other interface{}) bool { - if o, ok := other.(String); ok { - return self < o - } else { - panic("Cannot compare unequal types") - } -} - -func (self String) ByteSize() int { - return len(self) + 4 -} - -func (self String) WriteTo(w io.Writer) (n int64, err error) { - var n_ int - _, err = UInt32(len(self)).WriteTo(w) - if err != nil { - return n, err - } - n_, err = w.Write([]byte(self)) - return int64(n_ + 4), err -} - -func ReadStringSafe(r io.Reader) (str String, n int64, err error) { - length, n_, err := ReadUInt32Safe(r) - n += n_ - if err != nil { - return "", n, err - } - bytes := make([]byte, int(length)) - n__, err := io.ReadFull(r, bytes) - n += int64(n__) - if err != nil { - return "", n, err - } - return String(bytes), n, nil +func WriteString(w io.Writer, s string, n *int64, err *error) { + WriteUInt32(w, uint32(len(s)), n, err) + WriteTo(w, []byte(s), n, err) } -func ReadStringN(r io.Reader) (str String, n int64) { - str, n, err := ReadStringSafe(r) - if err != nil { - panic(err) +func ReadString(r io.Reader, n *int64, err *error) string { + length := ReadUInt32(r, n, err) + if *err != nil { + return "" } - return str, n -} - -func ReadString(r io.Reader) (str String) { - str, _, err := ReadStringSafe(r) - if err != nil { - panic(err) - } - return str -} - -func Readstring(r io.Reader) (str string) { - return string(ReadString(r)) + buf := make([]byte, int(length)) + ReadFull(r, buf, n, err) + return string(buf) } diff --git a/binary/time.go b/binary/time.go index d64284441..0c5a020be 100644 --- a/binary/time.go +++ b/binary/time.go @@ -5,58 +5,13 @@ import ( "time" ) -type Time struct { - time.Time -} - -func TimeFromUnix(secSinceEpoch int64) Time { - return Time{time.Unix(secSinceEpoch, 0)} -} - -func (self Time) Equals(other interface{}) bool { - if o, ok := other.(Time); ok { - return self.Equal(o.Time) - } else { - return false - } -} - -func (self Time) Less(other interface{}) bool { - if o, ok := other.(Time); ok { - return self.Before(o.Time) - } else { - panic("Cannot compare unequal types") - } -} - -func (self Time) ByteSize() int { - return 8 -} - -func (self Time) WriteTo(w io.Writer) (int64, error) { - return Int64(self.Unix()).WriteTo(w) -} - -func ReadTimeSafe(r io.Reader) (Time, int64, error) { - t, n, err := ReadInt64Safe(r) - if err != nil { - return Time{}, n, err - } - return Time{time.Unix(int64(t), 0)}, n, nil -} +// Time -func ReadTimeN(r io.Reader) (Time, int64) { - t, n, err := ReadTimeSafe(r) - if err != nil { - panic(err) - } - return t, n +func WriteTime(w io.Writer, t time.Time, n *int64, err *error) { + WriteInt64(w, t.Unix(), n, err) } -func ReadTime(r io.Reader) Time { - t, _, err := ReadTimeSafe(r) - if err != nil { - panic(err) - } - return t +func ReadTime(r io.Reader, n *int64, err *error) time.Time { + t := ReadInt64(r, n, err) + return time.Unix(t, 0) } diff --git a/binary/util.go b/binary/util.go index e08b8b56b..e8a52117b 100644 --- a/binary/util.go +++ b/binary/util.go @@ -5,10 +5,10 @@ import ( "crypto/sha256" ) -func BinaryBytes(b Binary) ByteSlice { +func BinaryBytes(b Binary) []byte { buf := bytes.NewBuffer(nil) b.WriteTo(buf) - return ByteSlice(buf.Bytes()) + return buf.Bytes() } // NOTE: does not care about the type, only the binary representation. @@ -25,11 +25,11 @@ func BinaryCompare(a, b Binary) int { return bytes.Compare(aBytes, bBytes) } -func BinaryHash(b Binary) ByteSlice { +func BinaryHash(b Binary) []byte { hasher := sha256.New() _, err := b.WriteTo(hasher) if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) + return hasher.Sum(nil) } diff --git a/merkle/iavl_node.go b/merkle/iavl_node.go index 3a80e7f81..a94b557dc 100644 --- a/merkle/iavl_node.go +++ b/merkle/iavl_node.go @@ -10,11 +10,11 @@ import ( // Node type IAVLNode struct { - key Key - value Value + key []byte + value []byte size uint64 height uint8 - hash ByteSlice + hash []byte left *IAVLNode right *IAVLNode @@ -27,7 +27,7 @@ const ( IAVLNODE_FLAG_PLACEHOLDER = byte(0x02) ) -func NewIAVLNode(key Key, value Value) *IAVLNode { +func NewIAVLNode(key []byte, value []byte) *IAVLNode { return &IAVLNode{ key: key, value: value, @@ -50,14 +50,6 @@ func (self *IAVLNode) Copy() *IAVLNode { } } -func (self *IAVLNode) Key() Key { - return self.key -} - -func (self *IAVLNode) Value() Value { - return self.value -} - func (self *IAVLNode) Size() uint64 { return self.size } @@ -66,14 +58,14 @@ func (self *IAVLNode) Height() uint8 { return self.height } -func (self *IAVLNode) has(db Db, key Key) (has bool) { - if self.key.Equals(key) { +func (self *IAVLNode) has(db Db, key []byte) (has bool) { + if bytes.Equal(self.key, key) { return true } if self.height == 0 { return false } else { - if key.Less(self.key) { + if bytes.Compare(key, self.key) == -1 { return self.leftFilled(db).has(db, key) } else { return self.rightFilled(db).has(db, key) @@ -81,15 +73,15 @@ func (self *IAVLNode) has(db Db, key Key) (has bool) { } } -func (self *IAVLNode) get(db Db, key Key) (value Value) { +func (self *IAVLNode) get(db Db, key []byte) (value []byte) { if self.height == 0 { - if self.key.Equals(key) { + if bytes.Equal(self.key, key) { return self.value } else { return nil } } else { - if key.Less(self.key) { + if bytes.Compare(key, self.key) == -1 { return self.leftFilled(db).get(db, key) } else { return self.rightFilled(db).get(db, key) @@ -97,7 +89,7 @@ func (self *IAVLNode) get(db Db, key Key) (value Value) { } } -func (self *IAVLNode) Hash() (ByteSlice, uint64) { +func (self *IAVLNode) HashWithCount() ([]byte, uint64) { if self.hash != nil { return self.hash, 0 } @@ -138,9 +130,9 @@ func (self *IAVLNode) Save(db Db) { self.flags |= IAVLNODE_FLAG_PERSISTED } -func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated bool) { +func (self *IAVLNode) set(db Db, key []byte, value []byte) (_ *IAVLNode, updated bool) { if self.height == 0 { - if key.Less(self.key) { + if bytes.Compare(key, self.key) == -1 { return &IAVLNode{ key: self.key, height: 1, @@ -148,7 +140,7 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo left: NewIAVLNode(key, value), right: self, }, false - } else if self.key.Equals(key) { + } else if bytes.Equal(self.key, key) { return NewIAVLNode(key, value), true } else { return &IAVLNode{ @@ -161,7 +153,7 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo } } else { self = self.Copy() - if key.Less(self.key) { + if bytes.Compare(key, self.key) == -1 { self.left, updated = self.leftFilled(db).set(db, key, value) } else { self.right, updated = self.rightFilled(db).set(db, key, value) @@ -176,15 +168,15 @@ func (self *IAVLNode) set(db Db, key Key, value Value) (_ *IAVLNode, updated boo } // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed. -func (self *IAVLNode) remove(db Db, key Key) (newSelf *IAVLNode, newKey Key, value Value, err error) { +func (self *IAVLNode) remove(db Db, key []byte) (newSelf *IAVLNode, newKey []byte, value []byte, err error) { if self.height == 0 { - if self.key.Equals(key) { + if bytes.Equal(self.key, key) { return nil, nil, self.value, nil } else { return self, nil, nil, NotFound(key) } } else { - if key.Less(self.key) { + if bytes.Compare(key, self.key) == -1 { var newLeft *IAVLNode newLeft, newKey, value, err = self.leftFilled(db).remove(db, key) if err != nil { @@ -220,74 +212,28 @@ func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) { } func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) { - var _n int64 - - // height & size - _n, err = UInt8(self.height).WriteTo(w) - if err != nil { - return - } else { - n += _n - } - _n, err = UInt64(self.size).WriteTo(w) - if err != nil { - return - } else { - n += _n - } - - // key - _n, err = Byte(GetBinaryType(self.key)).WriteTo(w) - if err != nil { - return - } else { - n += _n - } - _n, err = self.key.WriteTo(w) + // height & size & key + WriteUInt8(w, self.height, &n, &err) + WriteUInt64(w, self.size, &n, &err) + WriteByteSlice(w, self.key, &n, &err) if err != nil { return - } else { - n += _n } // value or children if self.height == 0 { // value - _n, err = Byte(GetBinaryType(self.value)).WriteTo(w) - if err != nil { - return - } else { - n += _n - } - if self.value != nil { - _n, err = self.value.WriteTo(w) - if err != nil { - return - } else { - n += _n - } - } + WriteByteSlice(w, self.value, &n, &err) } else { // left - leftHash, leftCount := self.left.Hash() + leftHash, leftCount := self.left.HashWithCount() hashCount += leftCount - _n, err = leftHash.WriteTo(w) - if err != nil { - return - } else { - n += _n - } + WriteByteSlice(w, leftHash, &n, &err) // right - rightHash, rightCount := self.right.Hash() + rightHash, rightCount := self.right.HashWithCount() hashCount += rightCount - _n, err = rightHash.WriteTo(w) - if err != nil { - return - } else { - n += _n - } + WriteByteSlice(w, rightHash, &n, &err) } - return } @@ -300,25 +246,30 @@ func (self *IAVLNode) fill(db Db) { } buf := db.Get(self.hash) r := bytes.NewReader(buf) - // node header - self.height = uint8(ReadUInt8(r)) - self.size = uint64(ReadUInt64(r)) - // key - key, _ := ReadBinaryN(r) - self.key = key.(Key) + var n int64 + var err error + // node header & key + self.height = ReadUInt8(r, &n, &err) + self.size = ReadUInt64(r, &n, &err) + self.key = ReadByteSlice(r, &n, &err) + if err != nil { + panic(err) + } + + // node value or children. if self.height == 0 { // value - self.value, _ = ReadBinaryN(r) + self.value = ReadByteSlice(r, &n, &err) } else { // left - leftHash := ReadByteSlice(r) + leftHash := ReadByteSlice(r, &n, &err) self.left = &IAVLNode{ hash: leftHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, } // right - rightHash := ReadByteSlice(r) + rightHash := ReadByteSlice(r, &n, &err) self.right = &IAVLNode{ hash: rightHash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, @@ -327,6 +278,9 @@ func (self *IAVLNode) fill(db Db) { panic("buf not all consumed") } } + if err != nil { + panic(err) + } self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER } @@ -425,7 +379,7 @@ func (self *IAVLNode) rmd(db Db) *IAVLNode { return self.rightFilled(db).rmd(db) } -func (self *IAVLNode) traverse(db Db, cb func(Node) bool) bool { +func (self *IAVLNode) traverse(db Db, cb func(*IAVLNode) bool) bool { stop := cb(self) if stop { return stop diff --git a/merkle/iavl_test.go b/merkle/iavl_test.go index 679f3d7f3..b1d1a7083 100644 --- a/merkle/iavl_test.go +++ b/merkle/iavl_test.go @@ -5,7 +5,6 @@ import ( "crypto/sha256" "fmt" - . "github.com/tendermint/tendermint/binary" . "github.com/tendermint/tendermint/common" "github.com/tendermint/tendermint/db" @@ -17,8 +16,8 @@ func init() { // TODO: seed rand? } -func randstr(length int) String { - return String(RandStr(length)) +func randstr(length int) string { + return RandStr(length) } func TestUnit(t *testing.T) { @@ -29,12 +28,12 @@ func TestUnit(t *testing.T) { if _, ok := l.(*IAVLNode); ok { left = l.(*IAVLNode) } else { - left = NewIAVLNode(Int32(l.(int)), nil) + left = NewIAVLNode([]byte{byte(l.(int))}, nil) } if _, ok := r.(*IAVLNode); ok { right = r.(*IAVLNode) } else { - right = NewIAVLNode(Int32(r.(int)), nil) + right = NewIAVLNode([]byte{byte(r.(int))}, nil) } n := &IAVLNode{ @@ -43,7 +42,7 @@ func TestUnit(t *testing.T) { right: right, } n.calcHeightAndSize(nil) - n.Hash() + n.HashWithCount() return n } @@ -51,7 +50,7 @@ func TestUnit(t *testing.T) { var P func(*IAVLNode) string P = func(n *IAVLNode) string { if n.height == 0 { - return fmt.Sprintf("%v", n.key) + return fmt.Sprintf("%v", n.key[0]) } else { return fmt.Sprintf("(%v %v)", P(n.left), P(n.right)) } @@ -59,24 +58,24 @@ func TestUnit(t *testing.T) { expectHash := func(n2 *IAVLNode, hashCount uint64) { // ensure number of new hash calculations is as expected. - hash, count := n2.Hash() + hash, count := n2.HashWithCount() if count != hashCount { t.Fatalf("Expected %v new hashes, got %v", hashCount, count) } // nuke hashes and reconstruct hash, ensure it's the same. - (&IAVLTree{root: n2}).Traverse(func(node Node) bool { - node.(*IAVLNode).hash = nil + n2.traverse(nil, func(node *IAVLNode) bool { + node.hash = nil return false }) // ensure that the new hash after nuking is the same as the old. - newHash, _ := n2.Hash() + newHash, _ := n2.HashWithCount() if bytes.Compare(hash, newHash) != 0 { t.Fatalf("Expected hash %v but got %v after nuking", hash, newHash) } } expectSet := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, updated := n.set(nil, Int32(i), nil) + n2, updated := n.set(nil, []byte{byte(i)}, nil) // ensure node was added & structure is as expected. if updated == true || P(n2) != repr { t.Fatalf("Adding %v to %v:\nExpected %v\nUnexpectedly got %v updated:%v", @@ -87,7 +86,7 @@ func TestUnit(t *testing.T) { } expectRemove := func(n *IAVLNode, i int, repr string, hashCount uint64) { - n2, _, value, err := n.remove(nil, Int32(i)) + n2, _, value, err := n.remove(nil, []byte{byte(i)}) // ensure node was added & structure is as expected. if value != nil || err != nil || P(n2) != repr { t.Fatalf("Removing %v from %v:\nExpected %v\nUnexpectedly got %v value:%v err:%v", @@ -137,14 +136,14 @@ func TestUnit(t *testing.T) { func TestIntegration(t *testing.T) { type record struct { - key String - value String + key string + value string } records := make([]*record, 400) var tree *IAVLTree = NewIAVLTree(nil) var err error - var val Value + var val []byte var updated bool randomRecord := func() *record { @@ -156,11 +155,11 @@ func TestIntegration(t *testing.T) { records[i] = r //t.Log("New record", r) //PrintIAVLNode(tree.root) - updated = tree.Set(r.key, String("")) + updated = tree.Set([]byte(r.key), []byte("")) if updated { t.Error("should have not been updated") } - updated = tree.Set(r.key, r.value) + updated = tree.Set([]byte(r.key), []byte(r.value)) if !updated { t.Error("should have been updated") } @@ -170,31 +169,32 @@ func TestIntegration(t *testing.T) { } for _, r := range records { - if has := tree.Has(r.key); !has { + if has := tree.Has([]byte(r.key)); !has { t.Error("Missing key", r.key) } - if has := tree.Has(randstr(12)); has { + if has := tree.Has([]byte(randstr(12))); has { t.Error("Table has extra key") } - if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + if val := tree.Get([]byte(r.key)); string(val) != r.value { t.Error("wrong value") } } for i, x := range records { - if val, err = tree.Remove(x.key); err != nil { + if val, err = tree.Remove([]byte(x.key)); err != nil { t.Error(err) - } else if !(val.(String)).Equals(x.value) { + } else if string(val) != x.value { t.Error("wrong value") } for _, r := range records[i+1:] { - if has := tree.Has(r.key); !has { + if has := tree.Has([]byte(r.key)); !has { t.Error("Missing key", r.key) } - if has := tree.Has(randstr(12)); has { + if has := tree.Has([]byte(randstr(12))); has { t.Error("Table has extra key") } - if val := tree.Get(r.key); !(val.(String)).Equals(r.value) { + val := tree.Get([]byte(r.key)) + if string(val) != r.value { t.Error("wrong value") } } @@ -208,25 +208,25 @@ func TestPersistence(t *testing.T) { db := db.NewMemDB() // Create some random key value pairs - records := make(map[String]String) + records := make(map[string]string) for i := 0; i < 10000; i++ { - records[String(randstr(20))] = String(randstr(20)) + records[randstr(20)] = randstr(20) } // Construct some tree and save it t1 := NewIAVLTree(db) for key, value := range records { - t1.Set(key, value) + t1.Set([]byte(key), []byte(value)) } t1.Save() - hash, _ := t1.Hash() + hash, _ := t1.HashWithCount() // Load a tree t2 := NewIAVLTreeFromHash(db, hash) for key, value := range records { - t2value := t2.Get(key) - if !BinaryEqual(t2value, value) { + t2value := t2.Get([]byte(key)) + if string(t2value) != value { t.Fatalf("Invalid value. Expected %v, got %v", value, t2value) } } @@ -249,8 +249,8 @@ func BenchmarkImmutableAvlTree(b *testing.B) { b.StopTimer() type record struct { - key String - value String + key string + value string } randomRecord := func() *record { @@ -260,7 +260,7 @@ func BenchmarkImmutableAvlTree(b *testing.B) { t := NewIAVLTree(nil) for i := 0; i < 1000000; i++ { r := randomRecord() - t.Set(r.key, r.value) + t.Set([]byte(r.key), []byte(r.value)) } fmt.Println("ok, starting") @@ -270,7 +270,7 @@ func BenchmarkImmutableAvlTree(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { r := randomRecord() - t.Set(r.key, r.value) - t.Remove(r.key) + t.Set([]byte(r.key), []byte(r.value)) + t.Remove([]byte(r.key)) } } diff --git a/merkle/iavl_tree.go b/merkle/iavl_tree.go index 286f848eb..d35b64c84 100644 --- a/merkle/iavl_tree.go +++ b/merkle/iavl_tree.go @@ -1,9 +1,5 @@ package merkle -import ( - . "github.com/tendermint/tendermint/binary" -) - const HASH_BYTE_SIZE int = 4 + 32 /* @@ -18,10 +14,14 @@ type IAVLTree struct { } func NewIAVLTree(db Db) *IAVLTree { - return &IAVLTree{db: db, root: nil} + return &IAVLTree{ + db: db, + root: nil, + } } -func NewIAVLTreeFromHash(db Db, hash ByteSlice) *IAVLTree { +// TODO rename to Load. +func NewIAVLTreeFromHash(db Db, hash []byte) *IAVLTree { root := &IAVLNode{ hash: hash, flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER, @@ -43,10 +43,6 @@ func NewIAVLTreeFromKey(db Db, key string) *IAVLTree { return &IAVLTree{db: db, root: root} } -func (t *IAVLTree) Root() Node { - return t.root -} - func (t *IAVLTree) Size() uint64 { if t.root == nil { return 0 @@ -61,14 +57,14 @@ func (t *IAVLTree) Height() uint8 { return t.root.Height() } -func (t *IAVLTree) Has(key Key) bool { +func (t *IAVLTree) Has(key []byte) bool { if t.root == nil { return false } return t.root.has(t.db, key) } -func (t *IAVLTree) Set(key Key, value Value) (updated bool) { +func (t *IAVLTree) Set(key []byte, value []byte) (updated bool) { if t.root == nil { t.root = NewIAVLNode(key, value) return false @@ -77,18 +73,26 @@ func (t *IAVLTree) Set(key Key, value Value) (updated bool) { return updated } -func (t *IAVLTree) Hash() (ByteSlice, uint64) { +func (t *IAVLTree) Hash() []byte { + if t.root == nil { + return nil + } + hash, _ := t.root.HashWithCount() + return hash +} + +func (t *IAVLTree) HashWithCount() ([]byte, uint64) { if t.root == nil { return nil, 0 } - return t.root.Hash() + return t.root.HashWithCount() } func (t *IAVLTree) Save() { if t.root == nil { return } - t.root.Hash() + t.root.HashWithCount() t.root.Save(t.db) } @@ -96,19 +100,19 @@ func (t *IAVLTree) SaveKey(key string) { if t.root == nil { return } - hash, _ := t.root.Hash() + hash, _ := t.root.HashWithCount() t.root.Save(t.db) t.db.Set([]byte(key), hash) } -func (t *IAVLTree) Get(key Key) (value Value) { +func (t *IAVLTree) Get(key []byte) (value []byte) { if t.root == nil { return nil } return t.root.get(t.db, key) } -func (t *IAVLTree) Remove(key Key) (value Value, err error) { +func (t *IAVLTree) Remove(key []byte) (value []byte, err error) { if t.root == nil { return nil, NotFound(key) } @@ -123,32 +127,3 @@ func (t *IAVLTree) Remove(key Key) (value Value, err error) { func (t *IAVLTree) Copy() Tree { return &IAVLTree{db: t.db, root: t.root} } - -// Traverses all the nodes of the tree in prefix order. -// return true from cb to halt iteration. -// node.Height() == 0 if you just want a value node. -func (t *IAVLTree) Traverse(cb func(Node) bool) { - if t.root == nil { - return - } - t.root.traverse(t.db, cb) -} - -func (t *IAVLTree) Values() <-chan Value { - root := t.root - ch := make(chan Value) - if root == nil { - close(ch) - return ch - } - go func() { - root.traverse(t.db, func(n Node) bool { - if n.Height() == 0 { - ch <- n.Value() - } - return true - }) - close(ch) - }() - return ch -} diff --git a/merkle/types.go b/merkle/types.go index 96fd3beda..fe432be37 100644 --- a/merkle/types.go +++ b/merkle/types.go @@ -2,50 +2,27 @@ package merkle import ( "fmt" - . "github.com/tendermint/tendermint/binary" ) -type Value interface { - Binary -} - -type Key interface { - Binary - Equals(interface{}) bool - Less(b interface{}) bool -} - type Db interface { Get([]byte) []byte Set([]byte, []byte) } -type Node interface { - Binary - Key() Key - Value() Value - Size() uint64 - Height() uint8 - Hash() (ByteSlice, uint64) - Save(Db) -} - type Tree interface { - Root() Node Size() uint64 Height() uint8 - Has(key Key) bool - Get(key Key) Value - Hash() (ByteSlice, uint64) + Has(key []byte) bool + Get(key []byte) []byte + HashWithCount() ([]byte, uint64) + Hash() []byte Save() SaveKey(string) - Set(Key, Value) bool - Remove(Key) (Value, error) + Set(key []byte, vlaue []byte) bool + Remove(key []byte) ([]byte, error) Copy() Tree - Traverse(func(Node) bool) - Values() <-chan Value } -func NotFound(key Key) error { +func NotFound(key []byte) error { return fmt.Errorf("Key was not found.") } diff --git a/merkle/util.go b/merkle/util.go index 578b406f4..8b538173d 100644 --- a/merkle/util.go +++ b/merkle/util.go @@ -7,10 +7,15 @@ import ( . "github.com/tendermint/tendermint/binary" ) +func HashFromByteSlices(items [][]byte) []byte { + panic("Implement me") + return nil +} + /* Compute a deterministic merkle hash from a list of byteslices. */ -func HashFromBinarySlice(items []Binary) ByteSlice { +func HashFromBinarySlice(items []Binary) []byte { switch len(items) { case 0: panic("Cannot compute hash of empty slice") @@ -20,18 +25,22 @@ func HashFromBinarySlice(items []Binary) ByteSlice { if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) + return hasher.Sum(nil) default: - hasher := sha256.New() - _, err := HashFromBinarySlice(items[0 : len(items)/2]).WriteTo(hasher) + var n int64 + var err error + var hasher = sha256.New() + hash := HashFromBinarySlice(items[0 : len(items)/2]) + WriteByteSlice(hasher, hash, &n, &err) if err != nil { panic(err) } - _, err = HashFromBinarySlice(items[len(items)/2:]).WriteTo(hasher) + hash = HashFromBinarySlice(items[len(items)/2:]) + WriteByteSlice(hasher, hash, &n, &err) if err != nil { panic(err) } - return ByteSlice(hasher.Sum(nil)) + return hasher.Sum(nil) } } diff --git a/state/store.go b/state/store.go new file mode 100644 index 000000000..8133fc890 --- /dev/null +++ b/state/store.go @@ -0,0 +1,19 @@ +package state + +import ( + . "github.com/tendermint/tendermint/blocks" +) + +// XXX ugh, bad name. +type StateStore struct { +} + +func (ss *StateStore) StageBlock(block *Block) error { + // XXX implement staging. + return nil +} + +func (ss *StateStore) CommitBlock(block *Block) error { + // XXX implement staging. + return nil +} diff --git a/state/validator.go b/state/validator.go new file mode 100644 index 000000000..884b0836c --- /dev/null +++ b/state/validator.go @@ -0,0 +1,141 @@ +package state + +import ( + "io" + + . "github.com/tendermint/tendermint/binary" + . "github.com/tendermint/tendermint/blocks" + //. "github.com/tendermint/tendermint/common" + db_ "github.com/tendermint/tendermint/db" +) + +// Holds state for a Validator at a given height+round. +// Meant to be discarded every round of the consensus protocol. +// TODO consider moving this to another common types package. +type Validator struct { + Account + BondHeight uint32 + VotingPower uint64 + Accum int64 +} + +// Used to persist the state of ConsensusStateControl. +func ReadValidator(r io.Reader) *Validator { + return &Validator{ + Account: Account{ + Id: Readuint64(r), + PubKey: ReadByteSlice(r), + }, + BondHeight: Readuint32(r), + VotingPower: Readuint64(r), + Accum: Readint64(r), + } +} + +// Creates a new copy of the validator so we can mutate accum. +func (v *Validator) Copy() *Validator { + return &Validator{ + Account: v.Account, + BondHeight: v.BondHeight, + VotingPower: v.VotingPower, + Accum: v.Accum, + } +} + +// Used to persist the state of ConsensusStateControl. +func (v *Validator) WriteTo(w io.Writer) (n int64, err error) { + n, err = WriteTo(UInt64(v.Id), w, n, err) + n, err = WriteTo(v.PubKey, w, n, err) + n, err = WriteTo(UInt32(v.BondHeight), w, n, err) + n, err = WriteTo(UInt64(v.VotingPower), w, n, err) + n, err = WriteTo(Int64(v.Accum), w, n, err) + return +} + +//----------------------------------------------------------------------------- + +// TODO: Ensure that double signing never happens via an external persistent check. +type PrivValidator struct { + PrivAccount + db *db_.LevelDB +} + +// Modifies the vote object in memory. +// Double signing results in an error. +func (pv *PrivValidator) SignVote(vote *Vote) error { + return nil +} + +//----------------------------------------------------------------------------- + +// Not goroutine-safe. +type ValidatorSet struct { + validators map[uint64]*Validator +} + +func NewValidatorSet(validators map[uint64]*Validator) *ValidatorSet { + if validators == nil { + validators = make(map[uint64]*Validator) + } + return &ValidatorSet{ + valdiators: validators, + } +} + +func (v *ValidatorSet) IncrementAccum() { + totalDelta := int64(0) + for _, validator := range v.validators { + validator.Accum += int64(validator.VotingPower) + totalDelta += int64(validator.VotingPower) + } + proposer := v.GetProposer() + proposer.Accum -= totalDelta + // NOTE: sum(v) here should be zero. + if true { + totalAccum := int64(0) + for _, validator := range v.validators { + totalAccum += validator.Accum + } + if totalAccum != 0 { + Panicf("Total Accum of validators did not equal 0. Got: ", totalAccum) + } + } +} + +func (v *ValidatorSet) Copy() *ValidatorSet { + mapCopy := map[uint64]*Validator{} + for _, val := range validators { + mapCopy[val.Id] = val.Copy() + } + return &ValidatorSet{ + validators: mapCopy, + } +} + +func (v *ValidatorSet) Add(validator *Valdaitor) { + v.validators[validator.Id] = validator +} + +func (v *ValidatorSet) Get(id uint64) *Validator { + return v.validators[validator.Id] +} + +func (v *ValidatorSet) Map() map[uint64]*Validator { + return v.validators +} + +// TODO: cache proposer. invalidate upon increment. +func (v *ValidatorSet) GetProposer() (proposer *Validator) { + highestAccum := int64(0) + for _, validator := range v.validators { + if validator.Accum > highestAccum { + highestAccum = validator.Accum + proposer = validator + } else if validator.Accum == highestAccum { + if validator.Id < proposer.Id { // Seniority + proposer = validator + } + } + } + return +}