You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

151 lines
4.3 KiB

  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync/atomic"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestParallel(t *testing.T) {
  11. // Create tasks.
  12. var counter = new(int32)
  13. var tasks = make([]Task, 100*1000)
  14. for i := 0; i < len(tasks); i++ {
  15. tasks[i] = func(i int) (res interface{}, err error, abort bool) {
  16. atomic.AddInt32(counter, 1)
  17. return -1 * i, nil, false
  18. }
  19. }
  20. // Run in parallel.
  21. var trs, ok = Parallel(tasks...)
  22. assert.True(t, ok)
  23. // Verify.
  24. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already")
  25. var failedTasks int
  26. for i := 0; i < len(tasks); i++ {
  27. taskResult, ok := trs.LatestResult(i)
  28. if !ok {
  29. assert.Fail(t, "Task #%v did not complete.", i)
  30. failedTasks += 1
  31. } else if taskResult.Error != nil {
  32. assert.Fail(t, "Task should not have errored but got %v", taskResult.Error)
  33. failedTasks += 1
  34. } else if taskResult.Panic != nil {
  35. assert.Fail(t, "Task should not have panic'd but got %v", taskResult.Panic)
  36. failedTasks += 1
  37. } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) {
  38. assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int))
  39. failedTasks += 1
  40. } else {
  41. // Good!
  42. }
  43. }
  44. assert.Equal(t, failedTasks, 0, "No task should have failed")
  45. assert.Nil(t, trs.FirstError(), "There should be no errors")
  46. assert.Nil(t, trs.FirstPanic(), "There should be no panics")
  47. assert.Equal(t, 0, trs.FirstValue(), "First value should be 0")
  48. }
  49. func TestParallelAbort(t *testing.T) {
  50. var flow1 = make(chan struct{}, 1)
  51. var flow2 = make(chan struct{}, 1)
  52. var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
  53. var flow4 = make(chan struct{}, 1)
  54. // Create tasks.
  55. var tasks = []Task{
  56. func(i int) (res interface{}, err error, abort bool) {
  57. assert.Equal(t, i, 0)
  58. flow1 <- struct{}{}
  59. return 0, nil, false
  60. },
  61. func(i int) (res interface{}, err error, abort bool) {
  62. assert.Equal(t, i, 1)
  63. flow2 <- <-flow1
  64. return 1, errors.New("some error"), false
  65. },
  66. func(i int) (res interface{}, err error, abort bool) {
  67. assert.Equal(t, i, 2)
  68. flow3 <- <-flow2
  69. return 2, nil, true
  70. },
  71. func(i int) (res interface{}, err error, abort bool) {
  72. assert.Equal(t, i, 3)
  73. <-flow4
  74. return 3, nil, false
  75. },
  76. }
  77. // Run in parallel.
  78. var taskResultSet, ok = Parallel(tasks...)
  79. assert.False(t, ok, "ok should be false since we aborted task #2.")
  80. // Verify task #3.
  81. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send.
  82. waitTimeout(t, taskResultSet.chz[3], "Task #3")
  83. // Now let the last task (#3) complete after abort.
  84. flow4 <- <-flow3
  85. // Verify task #0, #1, #2.
  86. checkResult(t, taskResultSet, 0, 0, nil, nil)
  87. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  88. checkResult(t, taskResultSet, 2, 2, nil, nil)
  89. }
  90. func TestParallelRecover(t *testing.T) {
  91. // Create tasks.
  92. var tasks = []Task{
  93. func(i int) (res interface{}, err error, abort bool) {
  94. return 0, nil, false
  95. },
  96. func(i int) (res interface{}, err error, abort bool) {
  97. return 1, errors.New("some error"), false
  98. },
  99. func(i int) (res interface{}, err error, abort bool) {
  100. panic(2)
  101. },
  102. }
  103. // Run in parallel.
  104. var taskResultSet, ok = Parallel(tasks...)
  105. assert.False(t, ok, "ok should be false since we panic'd in task #2.")
  106. // Verify task #0, #1, #2.
  107. checkResult(t, taskResultSet, 0, 0, nil, nil)
  108. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  109. checkResult(t, taskResultSet, 2, nil, nil, 2)
  110. }
  111. // Wait for result
  112. func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) {
  113. taskResult, ok := taskResultSet.LatestResult(index)
  114. taskName := fmt.Sprintf("Task #%v", index)
  115. assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
  116. assert.Equal(t, val, taskResult.Value, taskName)
  117. assert.Equal(t, err, taskResult.Error, taskName)
  118. assert.Equal(t, pnk, taskResult.Panic, taskName)
  119. }
  120. // Wait for timeout (no result)
  121. func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
  122. select {
  123. case _, ok := <-taskResultCh:
  124. if !ok {
  125. assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName)
  126. } else {
  127. assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName)
  128. }
  129. case <-time.After(1 * time.Second): // TODO use deterministic time?
  130. // Good!
  131. }
  132. }