You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
4.3 KiB

  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync/atomic"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestParallel(t *testing.T) {
  11. // Create tasks.
  12. var counter = new(int32)
  13. var tasks = make([]Task, 100*1000)
  14. for i := 0; i < len(tasks); i++ {
  15. tasks[i] = func(i int) (res interface{}, err error, abort bool) {
  16. atomic.AddInt32(counter, 1)
  17. return -1 * i, nil, false
  18. }
  19. }
  20. // Run in parallel.
  21. var trs, ok = Parallel(tasks...)
  22. assert.True(t, ok)
  23. // Verify.
  24. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already")
  25. var failedTasks int
  26. for i := 0; i < len(tasks); i++ {
  27. taskResult, ok := trs.LatestResult(i)
  28. if !ok {
  29. assert.Fail(t, "Task #%v did not complete.", i)
  30. failedTasks++
  31. } else if taskResult.Error != nil {
  32. assert.Fail(t, "Task should not have errored but got %v", taskResult.Error)
  33. failedTasks++
  34. } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) {
  35. assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int))
  36. failedTasks++
  37. } else {
  38. // Good!
  39. }
  40. }
  41. assert.Equal(t, failedTasks, 0, "No task should have failed")
  42. assert.Nil(t, trs.FirstError(), "There should be no errors")
  43. assert.Equal(t, 0, trs.FirstValue(), "First value should be 0")
  44. }
  45. func TestParallelAbort(t *testing.T) {
  46. var flow1 = make(chan struct{}, 1)
  47. var flow2 = make(chan struct{}, 1)
  48. var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
  49. var flow4 = make(chan struct{}, 1)
  50. // Create tasks.
  51. var tasks = []Task{
  52. func(i int) (res interface{}, err error, abort bool) {
  53. assert.Equal(t, i, 0)
  54. flow1 <- struct{}{}
  55. return 0, nil, false
  56. },
  57. func(i int) (res interface{}, err error, abort bool) {
  58. assert.Equal(t, i, 1)
  59. flow2 <- <-flow1
  60. return 1, errors.New("some error"), false
  61. },
  62. func(i int) (res interface{}, err error, abort bool) {
  63. assert.Equal(t, i, 2)
  64. flow3 <- <-flow2
  65. return 2, nil, true
  66. },
  67. func(i int) (res interface{}, err error, abort bool) {
  68. assert.Equal(t, i, 3)
  69. <-flow4
  70. return 3, nil, false
  71. },
  72. }
  73. // Run in parallel.
  74. var taskResultSet, ok = Parallel(tasks...)
  75. assert.False(t, ok, "ok should be false since we aborted task #2.")
  76. // Verify task #3.
  77. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send.
  78. waitTimeout(t, taskResultSet.chz[3], "Task #3")
  79. // Now let the last task (#3) complete after abort.
  80. flow4 <- <-flow3
  81. // Wait until all tasks have returned or panic'd.
  82. taskResultSet.Wait()
  83. // Verify task #0, #1, #2.
  84. checkResult(t, taskResultSet, 0, 0, nil, nil)
  85. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  86. checkResult(t, taskResultSet, 2, 2, nil, nil)
  87. checkResult(t, taskResultSet, 3, 3, nil, nil)
  88. }
  89. func TestParallelRecover(t *testing.T) {
  90. // Create tasks.
  91. var tasks = []Task{
  92. func(i int) (res interface{}, err error, abort bool) {
  93. return 0, nil, false
  94. },
  95. func(i int) (res interface{}, err error, abort bool) {
  96. return 1, errors.New("some error"), false
  97. },
  98. func(i int) (res interface{}, err error, abort bool) {
  99. panic(2)
  100. },
  101. }
  102. // Run in parallel.
  103. var taskResultSet, ok = Parallel(tasks...)
  104. assert.False(t, ok, "ok should be false since we panic'd in task #2.")
  105. // Verify task #0, #1, #2.
  106. checkResult(t, taskResultSet, 0, 0, nil, nil)
  107. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  108. checkResult(t, taskResultSet, 2, nil, nil, 2)
  109. }
  110. // Wait for result
  111. func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) {
  112. taskResult, ok := taskResultSet.LatestResult(index)
  113. taskName := fmt.Sprintf("Task #%v", index)
  114. assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
  115. assert.Equal(t, val, taskResult.Value, taskName)
  116. if err != nil {
  117. assert.Equal(t, err, taskResult.Error, taskName)
  118. } else if pnk != nil {
  119. assert.Equal(t, pnk, taskResult.Error.(Error).Data(), taskName)
  120. } else {
  121. assert.Nil(t, taskResult.Error, taskName)
  122. }
  123. }
  124. // Wait for timeout (no result)
  125. func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
  126. select {
  127. case _, ok := <-taskResultCh:
  128. if !ok {
  129. assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName)
  130. } else {
  131. assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName)
  132. }
  133. case <-time.After(1 * time.Second): // TODO use deterministic time?
  134. // Good!
  135. }
  136. }