You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

155 lines
3.6 KiB

  1. //nolint: gosec
  2. package main
  3. import (
  4. "crypto/sha256"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "io/ioutil"
  9. "os"
  10. "sort"
  11. "sync"
  12. )
// State is the application state: a key/value store plus the height and hash
// of the last commit. The embedded RWMutex is taken by the exported methods
// (Export, Import, Get, Set, Commit) around every access to the fields below.
type State struct {
	sync.RWMutex

	// Height is the height of the most recent commit (0 before the first one).
	Height uint64

	// Values holds the key/value pairs. Set removes a key rather than storing
	// an empty value for it.
	Values map[string]string

	// Hash is the hash of Values as computed by hashItems at the last commit
	// or import.
	Hash []byte

	// private fields aren't marshalled to disk.

	// file is the path the state is loaded from and saved to.
	file string

	// persistInterval makes Commit save to disk whenever the new height is a
	// multiple of it; 0 disables saving from Commit.
	persistInterval uint64

	// initialHeight, when non-zero, is the height assigned by the first
	// Commit instead of 1. It is not set anywhere in this file.
	initialHeight uint64
}
  24. // NewState creates a new state.
  25. func NewState(file string, persistInterval uint64) (*State, error) {
  26. state := &State{
  27. Values: make(map[string]string),
  28. file: file,
  29. persistInterval: persistInterval,
  30. }
  31. state.Hash = hashItems(state.Values)
  32. err := state.load()
  33. switch {
  34. case errors.Is(err, os.ErrNotExist):
  35. case err != nil:
  36. return nil, err
  37. }
  38. return state, nil
  39. }
  40. // load loads state from disk. It does not take out a lock, since it is called
  41. // during construction.
  42. func (s *State) load() error {
  43. bz, err := ioutil.ReadFile(s.file)
  44. if err != nil {
  45. return fmt.Errorf("failed to read state from %q: %w", s.file, err)
  46. }
  47. err = json.Unmarshal(bz, s)
  48. if err != nil {
  49. return fmt.Errorf("invalid state data in %q: %w", s.file, err)
  50. }
  51. return nil
  52. }
  53. // save saves the state to disk. It does not take out a lock since it is called
  54. // internally by Commit which does lock.
  55. func (s *State) save() error {
  56. bz, err := json.Marshal(s)
  57. if err != nil {
  58. return fmt.Errorf("failed to marshal state: %w", err)
  59. }
  60. // We write the state to a separate file and move it to the destination, to
  61. // make it atomic.
  62. newFile := fmt.Sprintf("%v.new", s.file)
  63. err = ioutil.WriteFile(newFile, bz, 0644)
  64. if err != nil {
  65. return fmt.Errorf("failed to write state to %q: %w", s.file, err)
  66. }
  67. return os.Rename(newFile, s.file)
  68. }
  69. // Export exports key/value pairs as JSON, used for state sync snapshots.
  70. func (s *State) Export() ([]byte, error) {
  71. s.RLock()
  72. defer s.RUnlock()
  73. return json.Marshal(s.Values)
  74. }
  75. // Import imports key/value pairs from JSON bytes, used for InitChain.AppStateBytes and
  76. // state sync snapshots. It also saves the state once imported.
  77. func (s *State) Import(height uint64, jsonBytes []byte) error {
  78. s.Lock()
  79. defer s.Unlock()
  80. values := map[string]string{}
  81. err := json.Unmarshal(jsonBytes, &values)
  82. if err != nil {
  83. return fmt.Errorf("failed to decode imported JSON data: %w", err)
  84. }
  85. s.Height = height
  86. s.Values = values
  87. s.Hash = hashItems(values)
  88. return s.save()
  89. }
  90. // Get fetches a value. A missing value is returned as an empty string.
  91. func (s *State) Get(key string) string {
  92. s.RLock()
  93. defer s.RUnlock()
  94. return s.Values[key]
  95. }
  96. // Set sets a value. Setting an empty value is equivalent to deleting it.
  97. func (s *State) Set(key, value string) {
  98. s.Lock()
  99. defer s.Unlock()
  100. if value == "" {
  101. delete(s.Values, key)
  102. } else {
  103. s.Values[key] = value
  104. }
  105. }
  106. // Commit commits the current state.
  107. func (s *State) Commit() (uint64, []byte, error) {
  108. s.Lock()
  109. defer s.Unlock()
  110. s.Hash = hashItems(s.Values)
  111. switch {
  112. case s.Height > 0:
  113. s.Height++
  114. case s.initialHeight > 0:
  115. s.Height = s.initialHeight
  116. default:
  117. s.Height = 1
  118. }
  119. if s.persistInterval > 0 && s.Height%s.persistInterval == 0 {
  120. err := s.save()
  121. if err != nil {
  122. return 0, nil, err
  123. }
  124. }
  125. return s.Height, s.Hash, nil
  126. }
  127. // hashItems hashes a set of key/value items.
  128. func hashItems(items map[string]string) []byte {
  129. keys := make([]string, 0, len(items))
  130. for key := range items {
  131. keys = append(keys, key)
  132. }
  133. sort.Strings(keys)
  134. hasher := sha256.New()
  135. for _, key := range keys {
  136. _, _ = hasher.Write([]byte(key))
  137. _, _ = hasher.Write([]byte{0})
  138. _, _ = hasher.Write([]byte(items[key]))
  139. _, _ = hasher.Write([]byte{0})
  140. }
  141. return hasher.Sum(nil)
  142. }