You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
4.8 KiB

  1. //nolint: gosec
  2. package app
  3. import (
  4. "crypto/sha256"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "os"
  9. "path/filepath"
  10. "sort"
  11. "sync"
  12. )
  13. const stateFileName = "app_state.json"
  14. const prevStateFileName = "prev_app_state.json"
// State is the application state: a key/value store plus the height and hash
// of its latest commit. The embedded RWMutex is taken by Get, Set, Export,
// Import and Commit. The exported fields are what gets written to, and read
// back from, the JSON state files.
type State struct {
	sync.RWMutex
	// Height is advanced by Commit and set directly by Import.
	Height uint64
	// Values holds the key/value items. Set never stores an empty value; it
	// deletes the key instead.
	Values map[string]string
	// Hash is hashItems(Values) as of construction, the last Import or the
	// last Commit. Set does not refresh it.
	Hash []byte

	// private fields aren't marshaled to disk.

	// currentFile is the path of the latest saved state.
	currentFile string
	// app saves current and previous state for rollback functionality
	previousFile string
	// persistInterval makes Commit save the state whenever the new height is
	// a multiple of it; zero disables saving from Commit.
	persistInterval uint64
	// initialHeight, when non-zero, is the height assigned by the first
	// Commit (i.e. while Height is still zero) instead of 1.
	initialHeight uint64
}
  28. // NewState creates a new state.
  29. func NewState(dir string, persistInterval uint64) (*State, error) {
  30. state := &State{
  31. Values: make(map[string]string),
  32. currentFile: filepath.Join(dir, stateFileName),
  33. previousFile: filepath.Join(dir, prevStateFileName),
  34. persistInterval: persistInterval,
  35. }
  36. state.Hash = hashItems(state.Values)
  37. err := state.load()
  38. switch {
  39. case errors.Is(err, os.ErrNotExist):
  40. case err != nil:
  41. return nil, err
  42. }
  43. return state, nil
  44. }
  45. // load loads state from disk. It does not take out a lock, since it is called
  46. // during construction.
  47. func (s *State) load() error {
  48. bz, err := os.ReadFile(s.currentFile)
  49. if err != nil {
  50. // if the current state doesn't exist then we try recover from the previous state
  51. if errors.Is(err, os.ErrNotExist) {
  52. bz, err = os.ReadFile(s.previousFile)
  53. if err != nil {
  54. return fmt.Errorf("failed to read both current and previous state (%q): %w",
  55. s.previousFile, err)
  56. }
  57. } else {
  58. return fmt.Errorf("failed to read state from %q: %w", s.currentFile, err)
  59. }
  60. }
  61. err = json.Unmarshal(bz, s)
  62. if err != nil {
  63. return fmt.Errorf("invalid state data in %q: %w", s.currentFile, err)
  64. }
  65. return nil
  66. }
  67. // save saves the state to disk. It does not take out a lock since it is called
  68. // internally by Commit which does lock.
  69. func (s *State) save() error {
  70. bz, err := json.Marshal(s)
  71. if err != nil {
  72. return fmt.Errorf("failed to marshal state: %w", err)
  73. }
  74. // We write the state to a separate file and move it to the destination, to
  75. // make it atomic.
  76. newFile := fmt.Sprintf("%v.new", s.currentFile)
  77. err = os.WriteFile(newFile, bz, 0644)
  78. if err != nil {
  79. return fmt.Errorf("failed to write state to %q: %w", s.currentFile, err)
  80. }
  81. // We take the current state and move it to the previous state, replacing it
  82. if _, err := os.Stat(s.currentFile); err == nil {
  83. if err := os.Rename(s.currentFile, s.previousFile); err != nil {
  84. return fmt.Errorf("failed to replace previous state: %w", err)
  85. }
  86. }
  87. // Finally, we take the new state and replace the current state.
  88. return os.Rename(newFile, s.currentFile)
  89. }
  90. // Export exports key/value pairs as JSON, used for state sync snapshots.
  91. func (s *State) Export() ([]byte, error) {
  92. s.RLock()
  93. defer s.RUnlock()
  94. return json.Marshal(s.Values)
  95. }
  96. // Import imports key/value pairs from JSON bytes, used for InitChain.AppStateBytes and
  97. // state sync snapshots. It also saves the state once imported.
  98. func (s *State) Import(height uint64, jsonBytes []byte) error {
  99. s.Lock()
  100. defer s.Unlock()
  101. values := map[string]string{}
  102. err := json.Unmarshal(jsonBytes, &values)
  103. if err != nil {
  104. return fmt.Errorf("failed to decode imported JSON data: %w", err)
  105. }
  106. s.Height = height
  107. s.Values = values
  108. s.Hash = hashItems(values)
  109. return s.save()
  110. }
  111. // Get fetches a value. A missing value is returned as an empty string.
  112. func (s *State) Get(key string) string {
  113. s.RLock()
  114. defer s.RUnlock()
  115. return s.Values[key]
  116. }
  117. // Set sets a value. Setting an empty value is equivalent to deleting it.
  118. func (s *State) Set(key, value string) {
  119. s.Lock()
  120. defer s.Unlock()
  121. if value == "" {
  122. delete(s.Values, key)
  123. } else {
  124. s.Values[key] = value
  125. }
  126. }
  127. // Commit commits the current state.
  128. func (s *State) Commit() (uint64, []byte, error) {
  129. s.Lock()
  130. defer s.Unlock()
  131. s.Hash = hashItems(s.Values)
  132. switch {
  133. case s.Height > 0:
  134. s.Height++
  135. case s.initialHeight > 0:
  136. s.Height = s.initialHeight
  137. default:
  138. s.Height = 1
  139. }
  140. if s.persistInterval > 0 && s.Height%s.persistInterval == 0 {
  141. err := s.save()
  142. if err != nil {
  143. return 0, nil, err
  144. }
  145. }
  146. return s.Height, s.Hash, nil
  147. }
  148. func (s *State) Rollback() error {
  149. bz, err := os.ReadFile(s.previousFile)
  150. if err != nil {
  151. return fmt.Errorf("failed to read state from %q: %w", s.previousFile, err)
  152. }
  153. err = json.Unmarshal(bz, s)
  154. if err != nil {
  155. return fmt.Errorf("invalid state data in %q: %w", s.previousFile, err)
  156. }
  157. return nil
  158. }
  159. // hashItems hashes a set of key/value items.
  160. func hashItems(items map[string]string) []byte {
  161. keys := make([]string, 0, len(items))
  162. for key := range items {
  163. keys = append(keys, key)
  164. }
  165. sort.Strings(keys)
  166. hasher := sha256.New()
  167. for _, key := range keys {
  168. _, _ = hasher.Write([]byte(key))
  169. _, _ = hasher.Write([]byte{0})
  170. _, _ = hasher.Write([]byte(items[key]))
  171. _, _ = hasher.Write([]byte{0})
  172. }
  173. return hasher.Sum(nil)
  174. }