You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
3.8 KiB

  1. package db
  2. import (
  3. "fmt"
  4. "sync"
  5. "testing"
  6. "github.com/stretchr/testify/assert"
  7. "github.com/stretchr/testify/require"
  8. cmn "github.com/tendermint/tmlibs/common"
  9. )
  10. //----------------------------------------
  11. // Helper functions.
  12. func checkValue(t *testing.T, db DB, key []byte, valueWanted []byte) {
  13. valueGot := db.Get(key)
  14. assert.Equal(t, valueWanted, valueGot)
  15. }
  16. func checkValid(t *testing.T, itr Iterator, expected bool) {
  17. valid := itr.Valid()
  18. require.Equal(t, expected, valid)
  19. }
  20. func checkNext(t *testing.T, itr Iterator, expected bool) {
  21. itr.Next()
  22. valid := itr.Valid()
  23. require.Equal(t, expected, valid)
  24. }
  25. func checkNextPanics(t *testing.T, itr Iterator) {
  26. assert.Panics(t, func() { itr.Next() }, "checkNextPanics expected panic but didn't")
  27. }
  28. func checkDomain(t *testing.T, itr Iterator, start, end []byte) {
  29. ds, de := itr.Domain()
  30. assert.Equal(t, start, ds, "checkDomain domain start incorrect")
  31. assert.Equal(t, end, de, "checkDomain domain end incorrect")
  32. }
  33. func checkItem(t *testing.T, itr Iterator, key []byte, value []byte) {
  34. k, v := itr.Key(), itr.Value()
  35. assert.Exactly(t, key, k)
  36. assert.Exactly(t, value, v)
  37. }
  38. func checkInvalid(t *testing.T, itr Iterator) {
  39. checkValid(t, itr, false)
  40. checkKeyPanics(t, itr)
  41. checkValuePanics(t, itr)
  42. checkNextPanics(t, itr)
  43. }
  44. func checkKeyPanics(t *testing.T, itr Iterator) {
  45. assert.Panics(t, func() { itr.Key() }, "checkKeyPanics expected panic but didn't")
  46. }
  47. func checkValuePanics(t *testing.T, itr Iterator) {
  48. assert.Panics(t, func() { itr.Key() }, "checkValuePanics expected panic but didn't")
  49. }
  50. func newTempDB(t *testing.T, backend DBBackendType) (db DB) {
  51. dir, dirname := cmn.Tempdir("db_common_test")
  52. db = NewDB("testdb", backend, dirname)
  53. dir.Close()
  54. return db
  55. }
  56. //----------------------------------------
  57. // mockDB
  58. // NOTE: not actually goroutine safe.
  59. // If you want something goroutine safe, maybe you just want a MemDB.
  60. type mockDB struct {
  61. mtx sync.Mutex
  62. calls map[string]int
  63. }
  64. func newMockDB() *mockDB {
  65. return &mockDB{
  66. calls: make(map[string]int),
  67. }
  68. }
  69. func (mdb *mockDB) Mutex() *sync.Mutex {
  70. return &(mdb.mtx)
  71. }
  72. func (mdb *mockDB) Get([]byte) []byte {
  73. mdb.calls["Get"]++
  74. return nil
  75. }
  76. func (mdb *mockDB) Has([]byte) bool {
  77. mdb.calls["Has"]++
  78. return false
  79. }
  80. func (mdb *mockDB) Set([]byte, []byte) {
  81. mdb.calls["Set"]++
  82. }
  83. func (mdb *mockDB) SetSync([]byte, []byte) {
  84. mdb.calls["SetSync"]++
  85. }
  86. func (mdb *mockDB) SetNoLock([]byte, []byte) {
  87. mdb.calls["SetNoLock"]++
  88. }
  89. func (mdb *mockDB) SetNoLockSync([]byte, []byte) {
  90. mdb.calls["SetNoLockSync"]++
  91. }
  92. func (mdb *mockDB) Delete([]byte) {
  93. mdb.calls["Delete"]++
  94. }
  95. func (mdb *mockDB) DeleteSync([]byte) {
  96. mdb.calls["DeleteSync"]++
  97. }
  98. func (mdb *mockDB) DeleteNoLock([]byte) {
  99. mdb.calls["DeleteNoLock"]++
  100. }
  101. func (mdb *mockDB) DeleteNoLockSync([]byte) {
  102. mdb.calls["DeleteNoLockSync"]++
  103. }
  104. func (mdb *mockDB) Iterator(start, end []byte) Iterator {
  105. mdb.calls["Iterator"]++
  106. return &mockIterator{}
  107. }
  108. func (mdb *mockDB) ReverseIterator(start, end []byte) Iterator {
  109. mdb.calls["ReverseIterator"]++
  110. return &mockIterator{}
  111. }
  112. func (mdb *mockDB) Close() {
  113. mdb.calls["Close"]++
  114. }
  115. func (mdb *mockDB) NewBatch() Batch {
  116. mdb.calls["NewBatch"]++
  117. return &memBatch{db: mdb}
  118. }
  119. func (mdb *mockDB) Print() {
  120. mdb.calls["Print"]++
  121. fmt.Printf("mockDB{%v}", mdb.Stats())
  122. }
  123. func (mdb *mockDB) Stats() map[string]string {
  124. mdb.calls["Stats"]++
  125. res := make(map[string]string)
  126. for key, count := range mdb.calls {
  127. res[key] = fmt.Sprintf("%d", count)
  128. }
  129. return res
  130. }
  //----------------------------------------
  // mockIterator

  // mockIterator is an iterator that is never valid. Every accessor returns
  // nil, Next and Close do nothing, and no method panics.
  type mockIterator struct{}

  // Domain reports nil for both bounds.
  func (mockIterator) Domain() (start, end []byte) {
  	return nil, nil
  }

  // Valid always reports false.
  func (mockIterator) Valid() bool { return false }

  // Next is a no-op.
  func (mockIterator) Next() {}

  // Key always returns nil.
  func (mockIterator) Key() []byte { return nil }

  // Value always returns nil.
  func (mockIterator) Value() []byte { return nil }

  // Close is a no-op.
  func (mockIterator) Close() {}