You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

154 lines
3.5 KiB

  1. //nolint: gosec
  2. package app
  3. import (
  4. "crypto/sha256"
  5. "encoding/json"
  6. "errors"
  7. "fmt"
  8. "os"
  9. "sort"
  10. "sync"
  11. )
  12. // State is the application state.
  13. type State struct {
  14. sync.RWMutex
  15. Height uint64
  16. Values map[string]string
  17. Hash []byte
  18. // private fields aren't marshaled to disk.
  19. file string
  20. persistInterval uint64
  21. initialHeight uint64
  22. }
  23. // NewState creates a new state.
  24. func NewState(file string, persistInterval uint64) (*State, error) {
  25. state := &State{
  26. Values: make(map[string]string),
  27. file: file,
  28. persistInterval: persistInterval,
  29. }
  30. state.Hash = hashItems(state.Values)
  31. err := state.load()
  32. switch {
  33. case errors.Is(err, os.ErrNotExist):
  34. case err != nil:
  35. return nil, err
  36. }
  37. return state, nil
  38. }
  39. // load loads state from disk. It does not take out a lock, since it is called
  40. // during construction.
  41. func (s *State) load() error {
  42. bz, err := os.ReadFile(s.file)
  43. if err != nil {
  44. return fmt.Errorf("failed to read state from %q: %w", s.file, err)
  45. }
  46. err = json.Unmarshal(bz, s)
  47. if err != nil {
  48. return fmt.Errorf("invalid state data in %q: %w", s.file, err)
  49. }
  50. return nil
  51. }
  52. // save saves the state to disk. It does not take out a lock since it is called
  53. // internally by Commit which does lock.
  54. func (s *State) save() error {
  55. bz, err := json.Marshal(s)
  56. if err != nil {
  57. return fmt.Errorf("failed to marshal state: %w", err)
  58. }
  59. // We write the state to a separate file and move it to the destination, to
  60. // make it atomic.
  61. newFile := fmt.Sprintf("%v.new", s.file)
  62. err = os.WriteFile(newFile, bz, 0644)
  63. if err != nil {
  64. return fmt.Errorf("failed to write state to %q: %w", s.file, err)
  65. }
  66. return os.Rename(newFile, s.file)
  67. }
  68. // Export exports key/value pairs as JSON, used for state sync snapshots.
  69. func (s *State) Export() ([]byte, error) {
  70. s.RLock()
  71. defer s.RUnlock()
  72. return json.Marshal(s.Values)
  73. }
  74. // Import imports key/value pairs from JSON bytes, used for InitChain.AppStateBytes and
  75. // state sync snapshots. It also saves the state once imported.
  76. func (s *State) Import(height uint64, jsonBytes []byte) error {
  77. s.Lock()
  78. defer s.Unlock()
  79. values := map[string]string{}
  80. err := json.Unmarshal(jsonBytes, &values)
  81. if err != nil {
  82. return fmt.Errorf("failed to decode imported JSON data: %w", err)
  83. }
  84. s.Height = height
  85. s.Values = values
  86. s.Hash = hashItems(values)
  87. return s.save()
  88. }
  89. // Get fetches a value. A missing value is returned as an empty string.
  90. func (s *State) Get(key string) string {
  91. s.RLock()
  92. defer s.RUnlock()
  93. return s.Values[key]
  94. }
  95. // Set sets a value. Setting an empty value is equivalent to deleting it.
  96. func (s *State) Set(key, value string) {
  97. s.Lock()
  98. defer s.Unlock()
  99. if value == "" {
  100. delete(s.Values, key)
  101. } else {
  102. s.Values[key] = value
  103. }
  104. }
  105. // Commit commits the current state.
  106. func (s *State) Commit() (uint64, []byte, error) {
  107. s.Lock()
  108. defer s.Unlock()
  109. s.Hash = hashItems(s.Values)
  110. switch {
  111. case s.Height > 0:
  112. s.Height++
  113. case s.initialHeight > 0:
  114. s.Height = s.initialHeight
  115. default:
  116. s.Height = 1
  117. }
  118. if s.persistInterval > 0 && s.Height%s.persistInterval == 0 {
  119. err := s.save()
  120. if err != nil {
  121. return 0, nil, err
  122. }
  123. }
  124. return s.Height, s.Hash, nil
  125. }
  126. // hashItems hashes a set of key/value items.
  127. func hashItems(items map[string]string) []byte {
  128. keys := make([]string, 0, len(items))
  129. for key := range items {
  130. keys = append(keys, key)
  131. }
  132. sort.Strings(keys)
  133. hasher := sha256.New()
  134. for _, key := range keys {
  135. _, _ = hasher.Write([]byte(key))
  136. _, _ = hasher.Write([]byte{0})
  137. _, _ = hasher.Write([]byte(items[key]))
  138. _, _ = hasher.Write([]byte{0})
  139. }
  140. return hasher.Sum(nil)
  141. }