|
|
- package merkle
-
- import (
- "bytes"
- "crypto/sha256"
- . "github.com/tendermint/tendermint/binary"
- "io"
- )
-
- // Node
-
// IAVLNode is a node of the IAVL tree. A node with height 0 is a leaf that
// stores a key/value pair; any other node is an inner node with two children.
type IAVLNode struct {
	// For leaves, the stored key. For inner nodes, the split key: lookups for
	// keys less than it descend left, all others descend right.
	key   []byte
	value []byte // only set on leaves

	size   uint64 // number of leaves in this subtree (1 for a leaf)
	height uint8  // 0 for a leaf
	hash   []byte // cached node hash; nil until computed

	// Children of an inner node. Children decoded from the db start out as
	// placeholders that only have their hash set (see fill).
	left, right *IAVLNode

	// volatile
	flags byte // bitset of IAVLNODE_FLAG_* values
}
-
// Bit flags kept in IAVLNode.flags.
const (
	// IAVLNODE_FLAG_PERSISTED marks a node whose encoding is already stored
	// in the db, so Save skips it.
	IAVLNODE_FLAG_PERSISTED byte = 1 << iota
	// IAVLNODE_FLAG_PLACEHOLDER marks a node that only has its hash set and
	// must be filled from the db before its other fields are read.
	IAVLNODE_FLAG_PLACEHOLDER
)
-
- func NewIAVLNode(key []byte, value []byte) *IAVLNode {
- return &IAVLNode{
- key: key,
- value: value,
- size: 1,
- }
- }
-
- func (self *IAVLNode) Copy() *IAVLNode {
- if self.height == 0 {
- panic("Why are you copying a value node?")
- }
- return &IAVLNode{
- key: self.key,
- size: self.size,
- height: self.height,
- left: self.left,
- right: self.right,
- hash: nil,
- flags: byte(0),
- }
- }
-
- func (self *IAVLNode) Size() uint64 {
- return self.size
- }
-
- func (self *IAVLNode) Height() uint8 {
- return self.height
- }
-
- func (self *IAVLNode) has(db Db, key []byte) (has bool) {
- if bytes.Equal(self.key, key) {
- return true
- }
- if self.height == 0 {
- return false
- } else {
- if bytes.Compare(key, self.key) == -1 {
- return self.leftFilled(db).has(db, key)
- } else {
- return self.rightFilled(db).has(db, key)
- }
- }
- }
-
- func (self *IAVLNode) get(db Db, key []byte) (value []byte) {
- if self.height == 0 {
- if bytes.Equal(self.key, key) {
- return self.value
- } else {
- return nil
- }
- } else {
- if bytes.Compare(key, self.key) == -1 {
- return self.leftFilled(db).get(db, key)
- } else {
- return self.rightFilled(db).get(db, key)
- }
- }
- }
-
- func (self *IAVLNode) HashWithCount() ([]byte, uint64) {
- if self.hash != nil {
- return self.hash, 0
- }
-
- hasher := sha256.New()
- _, hashCount, err := self.saveToCountHashes(hasher)
- if err != nil {
- panic(err)
- }
- self.hash = hasher.Sum(nil)
-
- return self.hash, hashCount + 1
- }
-
- func (self *IAVLNode) Save(db Db) {
- if self.hash == nil {
- panic("savee.hash can't be nil")
- }
- if self.flags&IAVLNODE_FLAG_PERSISTED > 0 ||
- self.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
- return
- }
-
- // children
- if self.height > 0 {
- self.left.Save(db)
- self.right.Save(db)
- }
-
- // save self
- buf := bytes.NewBuffer(nil)
- _, err := self.WriteTo(buf)
- if err != nil {
- panic(err)
- }
- db.Set([]byte(self.hash), buf.Bytes())
-
- self.flags |= IAVLNODE_FLAG_PERSISTED
- }
-
- func (self *IAVLNode) set(db Db, key []byte, value []byte) (_ *IAVLNode, updated bool) {
- if self.height == 0 {
- if bytes.Compare(key, self.key) == -1 {
- return &IAVLNode{
- key: self.key,
- height: 1,
- size: 2,
- left: NewIAVLNode(key, value),
- right: self,
- }, false
- } else if bytes.Equal(self.key, key) {
- return NewIAVLNode(key, value), true
- } else {
- return &IAVLNode{
- key: key,
- height: 1,
- size: 2,
- left: self,
- right: NewIAVLNode(key, value),
- }, false
- }
- } else {
- self = self.Copy()
- if bytes.Compare(key, self.key) == -1 {
- self.left, updated = self.leftFilled(db).set(db, key, value)
- } else {
- self.right, updated = self.rightFilled(db).set(db, key, value)
- }
- if updated {
- return self, updated
- } else {
- self.calcHeightAndSize(db)
- return self.balance(db), updated
- }
- }
- }
-
- // newKey: new leftmost leaf key for tree after successfully removing 'key' if changed.
- func (self *IAVLNode) remove(db Db, key []byte) (newSelf *IAVLNode, newKey []byte, value []byte, err error) {
- if self.height == 0 {
- if bytes.Equal(self.key, key) {
- return nil, nil, self.value, nil
- } else {
- return self, nil, nil, NotFound(key)
- }
- } else {
- if bytes.Compare(key, self.key) == -1 {
- var newLeft *IAVLNode
- newLeft, newKey, value, err = self.leftFilled(db).remove(db, key)
- if err != nil {
- return self, nil, value, err
- } else if newLeft == nil { // left node held value, was removed
- return self.right, self.key, value, nil
- }
- self = self.Copy()
- self.left = newLeft
- } else {
- var newRight *IAVLNode
- newRight, newKey, value, err = self.rightFilled(db).remove(db, key)
- if err != nil {
- return self, nil, value, err
- } else if newRight == nil { // right node held value, was removed
- return self.left, nil, value, nil
- }
- self = self.Copy()
- self.right = newRight
- if newKey != nil {
- self.key = newKey
- newKey = nil
- }
- }
- self.calcHeightAndSize(db)
- return self.balance(db), newKey, value, err
- }
- }
-
- func (self *IAVLNode) WriteTo(w io.Writer) (n int64, err error) {
- n, _, err = self.saveToCountHashes(w)
- return
- }
-
- func (self *IAVLNode) saveToCountHashes(w io.Writer) (n int64, hashCount uint64, err error) {
- // height & size & key
- WriteUInt8(w, self.height, &n, &err)
- WriteUInt64(w, self.size, &n, &err)
- WriteByteSlice(w, self.key, &n, &err)
- if err != nil {
- return
- }
-
- // value or children
- if self.height == 0 {
- // value
- WriteByteSlice(w, self.value, &n, &err)
- } else {
- // left
- leftHash, leftCount := self.left.HashWithCount()
- hashCount += leftCount
- WriteByteSlice(w, leftHash, &n, &err)
- // right
- rightHash, rightCount := self.right.HashWithCount()
- hashCount += rightCount
- WriteByteSlice(w, rightHash, &n, &err)
- }
- return
- }
-
- // Given a placeholder node which has only the hash set,
- // load the rest of the data from db.
- // Not threadsafe.
- func (self *IAVLNode) fill(db Db) {
- if self.hash == nil {
- panic("placeholder.hash can't be nil")
- }
- buf := db.Get(self.hash)
- r := bytes.NewReader(buf)
- var n int64
- var err error
-
- // node header & key
- self.height = ReadUInt8(r, &n, &err)
- self.size = ReadUInt64(r, &n, &err)
- self.key = ReadByteSlice(r, &n, &err)
- if err != nil {
- panic(err)
- }
-
- // node value or children.
- if self.height == 0 {
- // value
- self.value = ReadByteSlice(r, &n, &err)
- } else {
- // left
- leftHash := ReadByteSlice(r, &n, &err)
- self.left = &IAVLNode{
- hash: leftHash,
- flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
- }
- // right
- rightHash := ReadByteSlice(r, &n, &err)
- self.right = &IAVLNode{
- hash: rightHash,
- flags: IAVLNODE_FLAG_PERSISTED | IAVLNODE_FLAG_PLACEHOLDER,
- }
- if r.Len() != 0 {
- panic("buf not all consumed")
- }
- }
- if err != nil {
- panic(err)
- }
- self.flags &= ^IAVLNODE_FLAG_PLACEHOLDER
- }
-
- func (self *IAVLNode) leftFilled(db Db) *IAVLNode {
- if self.left.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
- self.left.fill(db)
- }
- return self.left
- }
-
- func (self *IAVLNode) rightFilled(db Db) *IAVLNode {
- if self.right.flags&IAVLNODE_FLAG_PLACEHOLDER > 0 {
- self.right.fill(db)
- }
- return self.right
- }
-
- func (self *IAVLNode) rotateRight(db Db) *IAVLNode {
- self = self.Copy()
- sl := self.leftFilled(db).Copy()
- slr := sl.right
-
- sl.right = self
- self.left = slr
-
- self.calcHeightAndSize(db)
- sl.calcHeightAndSize(db)
-
- return sl
- }
-
- func (self *IAVLNode) rotateLeft(db Db) *IAVLNode {
- self = self.Copy()
- sr := self.rightFilled(db).Copy()
- srl := sr.left
-
- sr.left = self
- self.right = srl
-
- self.calcHeightAndSize(db)
- sr.calcHeightAndSize(db)
-
- return sr
- }
-
- func (self *IAVLNode) calcHeightAndSize(db Db) {
- self.height = maxUint8(self.leftFilled(db).Height(), self.rightFilled(db).Height()) + 1
- self.size = self.leftFilled(db).Size() + self.rightFilled(db).Size()
- }
-
- func (self *IAVLNode) calcBalance(db Db) int {
- return int(self.leftFilled(db).Height()) - int(self.rightFilled(db).Height())
- }
-
- func (self *IAVLNode) balance(db Db) (newSelf *IAVLNode) {
- balance := self.calcBalance(db)
- if balance > 1 {
- if self.leftFilled(db).calcBalance(db) >= 0 {
- // Left Left Case
- return self.rotateRight(db)
- } else {
- // Left Right Case
- self = self.Copy()
- self.left = self.leftFilled(db).rotateLeft(db)
- //self.calcHeightAndSize()
- return self.rotateRight(db)
- }
- }
- if balance < -1 {
- if self.rightFilled(db).calcBalance(db) <= 0 {
- // Right Right Case
- return self.rotateLeft(db)
- } else {
- // Right Left Case
- self = self.Copy()
- self.right = self.rightFilled(db).rotateRight(db)
- //self.calcHeightAndSize()
- return self.rotateLeft(db)
- }
- }
- // Nothing changed
- return self
- }
-
- func (self *IAVLNode) lmd(db Db) *IAVLNode {
- if self.height == 0 {
- return self
- }
- return self.leftFilled(db).lmd(db)
- }
-
- func (self *IAVLNode) rmd(db Db) *IAVLNode {
- if self.height == 0 {
- return self
- }
- return self.rightFilled(db).rmd(db)
- }
-
- func (self *IAVLNode) traverse(db Db, cb func(*IAVLNode) bool) bool {
- stop := cb(self)
- if stop {
- return stop
- }
- if self.height > 0 {
- stop = self.leftFilled(db).traverse(db, cb)
- if stop {
- return stop
- }
- stop = self.rightFilled(db).traverse(db, cb)
- if stop {
- return stop
- }
- }
- return false
- }
|