You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

138 lines
4.0 KiB

  1. package merkle
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. )
  7. //----------------------------------------
  8. // ProofOp gets converted to an instance of ProofOperator:
  9. // ProofOperator is a layer for calculating intermediate Merkle roots
  10. // when a series of Merkle trees are chained together.
  11. // Run() takes leaf values from a tree and returns the Merkle
  12. // root for the corresponding tree. It takes and returns a list of bytes
  13. // to allow multiple leaves to be part of a single proof, for instance in a range proof.
  14. // ProofOp() encodes the ProofOperator in a generic way so it can later be
  15. // decoded with OpDecoder.
  16. type ProofOperator interface {
  17. Run([][]byte) ([][]byte, error)
  18. GetKey() []byte
  19. ProofOp() ProofOp
  20. }
  21. //----------------------------------------
  22. // Operations on a list of ProofOperators
  23. // ProofOperators is a slice of ProofOperator(s).
  24. // Each operator will be applied to the input value sequentially
  25. // and the last Merkle root will be verified with already known data
  26. type ProofOperators []ProofOperator
  27. func (poz ProofOperators) VerifyValue(root []byte, keypath string, value []byte) (err error) {
  28. return poz.Verify(root, keypath, [][]byte{value})
  29. }
  30. func (poz ProofOperators) Verify(root []byte, keypath string, args [][]byte) (err error) {
  31. keys, err := KeyPathToKeys(keypath)
  32. if err != nil {
  33. return
  34. }
  35. for i, op := range poz {
  36. key := op.GetKey()
  37. if len(key) != 0 {
  38. if len(keys) == 0 {
  39. return fmt.Errorf("key path has insufficient # of parts: expected no more keys but got %+v", string(key))
  40. }
  41. lastKey := keys[len(keys)-1]
  42. if !bytes.Equal(lastKey, key) {
  43. return fmt.Errorf("key mismatch on operation #%d: expected %+v but got %+v", i, string(lastKey), string(key))
  44. }
  45. keys = keys[:len(keys)-1]
  46. }
  47. args, err = op.Run(args)
  48. if err != nil {
  49. return
  50. }
  51. }
  52. if !bytes.Equal(root, args[0]) {
  53. return fmt.Errorf("calculated root hash is invalid: expected %+v but got %+v", root, args[0])
  54. }
  55. if len(keys) != 0 {
  56. return errors.New("keypath not consumed all")
  57. }
  58. return nil
  59. }
  60. //----------------------------------------
  61. // ProofRuntime - main entrypoint
  62. type OpDecoder func(ProofOp) (ProofOperator, error)
  63. type ProofRuntime struct {
  64. decoders map[string]OpDecoder
  65. }
  66. func NewProofRuntime() *ProofRuntime {
  67. return &ProofRuntime{
  68. decoders: make(map[string]OpDecoder),
  69. }
  70. }
  71. func (prt *ProofRuntime) RegisterOpDecoder(typ string, dec OpDecoder) {
  72. _, ok := prt.decoders[typ]
  73. if ok {
  74. panic("already registered for type " + typ)
  75. }
  76. prt.decoders[typ] = dec
  77. }
  78. func (prt *ProofRuntime) Decode(pop ProofOp) (ProofOperator, error) {
  79. decoder := prt.decoders[pop.Type]
  80. if decoder == nil {
  81. return nil, fmt.Errorf("unrecognized proof type %v", pop.Type)
  82. }
  83. return decoder(pop)
  84. }
  85. func (prt *ProofRuntime) DecodeProof(proof *Proof) (ProofOperators, error) {
  86. poz := make(ProofOperators, 0, len(proof.Ops))
  87. for _, pop := range proof.Ops {
  88. operator, err := prt.Decode(pop)
  89. if err != nil {
  90. return nil, fmt.Errorf("decoding a proof operator: %w", err)
  91. }
  92. poz = append(poz, operator)
  93. }
  94. return poz, nil
  95. }
  96. func (prt *ProofRuntime) VerifyValue(proof *Proof, root []byte, keypath string, value []byte) (err error) {
  97. return prt.Verify(proof, root, keypath, [][]byte{value})
  98. }
  99. // TODO In the long run we'll need a method of classifcation of ops,
  100. // whether existence or absence or perhaps a third?
  101. func (prt *ProofRuntime) VerifyAbsence(proof *Proof, root []byte, keypath string) (err error) {
  102. return prt.Verify(proof, root, keypath, nil)
  103. }
  104. func (prt *ProofRuntime) Verify(proof *Proof, root []byte, keypath string, args [][]byte) (err error) {
  105. poz, err := prt.DecodeProof(proof)
  106. if err != nil {
  107. return fmt.Errorf("decoding proof: %w", err)
  108. }
  109. return poz.Verify(root, keypath, args)
  110. }
  111. // DefaultProofRuntime only knows about Simple value
  112. // proofs.
  113. // To use e.g. IAVL proofs, register op-decoders as
  114. // defined in the IAVL package.
  115. func DefaultProofRuntime() (prt *ProofRuntime) {
  116. prt = NewProofRuntime()
  117. prt.RegisterOpDecoder(ProofOpSimpleValue, SimpleValueOpDecoder)
  118. return
  119. }