You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

152 lines
4.2 KiB

  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync/atomic"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestParallel(t *testing.T) {
  11. // Create tasks.
  12. var counter = new(int32)
  13. var tasks = make([]Task, 100*1000)
  14. for i := 0; i < len(tasks); i++ {
  15. tasks[i] = func(i int) (res interface{}, err error, abort bool) {
  16. atomic.AddInt32(counter, 1)
  17. return -1 * i, nil, false
  18. }
  19. }
  20. // Run in parallel.
  21. var trs, ok = Parallel(tasks...)
  22. assert.True(t, ok)
  23. // Verify.
  24. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already")
  25. var failedTasks int
  26. for i := 0; i < len(tasks); i++ {
  27. taskResult, ok := trs.LatestResult(i)
  28. if !ok {
  29. assert.Fail(t, "Task #%v did not complete.", i)
  30. failedTasks++
  31. } else if taskResult.Error != nil {
  32. assert.Fail(t, "Task should not have errored but got %v", taskResult.Error)
  33. failedTasks++
  34. } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) {
  35. assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int))
  36. failedTasks++
  37. } else {
  38. // Good!
  39. }
  40. }
  41. assert.Equal(t, failedTasks, 0, "No task should have failed")
  42. assert.Nil(t, trs.FirstError(), "There should be no errors")
  43. assert.Equal(t, 0, trs.FirstValue(), "First value should be 0")
  44. }
  45. func TestParallelAbort(t *testing.T) {
  46. var flow1 = make(chan struct{}, 1)
  47. var flow2 = make(chan struct{}, 1)
  48. var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
  49. var flow4 = make(chan struct{}, 1)
  50. // Create tasks.
  51. var tasks = []Task{
  52. func(i int) (res interface{}, err error, abort bool) {
  53. assert.Equal(t, i, 0)
  54. flow1 <- struct{}{}
  55. return 0, nil, false
  56. },
  57. func(i int) (res interface{}, err error, abort bool) {
  58. assert.Equal(t, i, 1)
  59. flow2 <- <-flow1
  60. return 1, errors.New("some error"), false
  61. },
  62. func(i int) (res interface{}, err error, abort bool) {
  63. assert.Equal(t, i, 2)
  64. flow3 <- <-flow2
  65. return 2, nil, true
  66. },
  67. func(i int) (res interface{}, err error, abort bool) {
  68. assert.Equal(t, i, 3)
  69. <-flow4
  70. return 3, nil, false
  71. },
  72. }
  73. // Run in parallel.
  74. var taskResultSet, ok = Parallel(tasks...)
  75. assert.False(t, ok, "ok should be false since we aborted task #2.")
  76. // Verify task #3.
  77. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send.
  78. waitTimeout(t, taskResultSet.chz[3], "Task #3")
  79. // Now let the last task (#3) complete after abort.
  80. flow4 <- <-flow3
  81. // Verify task #0, #1, #2.
  82. checkResult(t, taskResultSet, 0, 0, nil, nil)
  83. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  84. checkResult(t, taskResultSet, 2, 2, nil, nil)
  85. }
  86. func TestParallelRecover(t *testing.T) {
  87. // Create tasks.
  88. var tasks = []Task{
  89. func(i int) (res interface{}, err error, abort bool) {
  90. return 0, nil, false
  91. },
  92. func(i int) (res interface{}, err error, abort bool) {
  93. return 1, errors.New("some error"), false
  94. },
  95. func(i int) (res interface{}, err error, abort bool) {
  96. panic(2)
  97. },
  98. }
  99. // Run in parallel.
  100. var taskResultSet, ok = Parallel(tasks...)
  101. assert.False(t, ok, "ok should be false since we panic'd in task #2.")
  102. // Verify task #0, #1, #2.
  103. checkResult(t, taskResultSet, 0, 0, nil, nil)
  104. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  105. checkResult(t, taskResultSet, 2, nil, nil, 2)
  106. }
  107. // Wait for result
  108. func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) {
  109. taskResult, ok := taskResultSet.LatestResult(index)
  110. taskName := fmt.Sprintf("Task #%v", index)
  111. assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
  112. assert.Equal(t, val, taskResult.Value, taskName)
  113. if err != nil {
  114. assert.Equal(t, err, taskResult.Error, taskName)
  115. } else if pnk != nil {
  116. assert.Equal(t, pnk, taskResult.Error.(Error).Cause(), taskName)
  117. } else {
  118. assert.Nil(t, taskResult.Error, taskName)
  119. }
  120. }
  121. // Wait for timeout (no result)
  122. func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
  123. select {
  124. case _, ok := <-taskResultCh:
  125. if !ok {
  126. assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName)
  127. } else {
  128. assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName)
  129. }
  130. case <-time.After(1 * time.Second): // TODO use deterministic time?
  131. // Good!
  132. }
  133. }