You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

160 lines
4.4 KiB

  1. package common
  2. import (
  3. "fmt"
  4. "sync/atomic"
  5. "testing"
  6. "time"
  7. "github.com/pkg/errors"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestParallel(t *testing.T) {
  11. // Create tasks.
  12. var counter = new(int32)
  13. var tasks = make([]Task, 100*1000)
  14. for i := 0; i < len(tasks); i++ {
  15. tasks[i] = func(i int) (res interface{}, abort bool, err error) {
  16. atomic.AddInt32(counter, 1)
  17. return -1 * i, false, nil
  18. }
  19. }
  20. // Run in parallel.
  21. var trs, ok = Parallel(tasks...)
  22. assert.True(t, ok)
  23. // Verify.
  24. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already")
  25. var failedTasks int
  26. for i := 0; i < len(tasks); i++ {
  27. taskResult, ok := trs.LatestResult(i)
  28. switch {
  29. case !ok:
  30. assert.Fail(t, "Task #%v did not complete.", i)
  31. failedTasks++
  32. case taskResult.Error != nil:
  33. assert.Fail(t, "Task should not have errored but got %v", taskResult.Error)
  34. failedTasks++
  35. case !assert.Equal(t, -1*i, taskResult.Value.(int)):
  36. assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int))
  37. failedTasks++
  38. }
  39. // else {
  40. // Good!
  41. // }
  42. }
  43. assert.Equal(t, failedTasks, 0, "No task should have failed")
  44. assert.Nil(t, trs.FirstError(), "There should be no errors")
  45. assert.Equal(t, 0, trs.FirstValue(), "First value should be 0")
  46. }
  47. func TestParallelAbort(t *testing.T) {
  48. var flow1 = make(chan struct{}, 1)
  49. var flow2 = make(chan struct{}, 1)
  50. var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
  51. var flow4 = make(chan struct{}, 1)
  52. // Create tasks.
  53. var tasks = []Task{
  54. func(i int) (res interface{}, abort bool, err error) {
  55. assert.Equal(t, i, 0)
  56. flow1 <- struct{}{}
  57. return 0, false, nil
  58. },
  59. func(i int) (res interface{}, abort bool, err error) {
  60. assert.Equal(t, i, 1)
  61. flow2 <- <-flow1
  62. return 1, false, errors.New("some error")
  63. },
  64. func(i int) (res interface{}, abort bool, err error) {
  65. assert.Equal(t, i, 2)
  66. flow3 <- <-flow2
  67. return 2, true, nil
  68. },
  69. func(i int) (res interface{}, abort bool, err error) {
  70. assert.Equal(t, i, 3)
  71. <-flow4
  72. return 3, false, nil
  73. },
  74. }
  75. // Run in parallel.
  76. var taskResultSet, ok = Parallel(tasks...)
  77. assert.False(t, ok, "ok should be false since we aborted task #2.")
  78. // Verify task #3.
  79. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send.
  80. waitTimeout(t, taskResultSet.chz[3], "Task #3")
  81. // Now let the last task (#3) complete after abort.
  82. flow4 <- <-flow3
  83. // Wait until all tasks have returned or panic'd.
  84. taskResultSet.Wait()
  85. // Verify task #0, #1, #2.
  86. checkResult(t, taskResultSet, 0, 0, nil, nil)
  87. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  88. checkResult(t, taskResultSet, 2, 2, nil, nil)
  89. checkResult(t, taskResultSet, 3, 3, nil, nil)
  90. }
  91. func TestParallelRecover(t *testing.T) {
  92. // Create tasks.
  93. var tasks = []Task{
  94. func(i int) (res interface{}, abort bool, err error) {
  95. return 0, false, nil
  96. },
  97. func(i int) (res interface{}, abort bool, err error) {
  98. return 1, false, errors.New("some error")
  99. },
  100. func(i int) (res interface{}, abort bool, err error) {
  101. panic(2)
  102. },
  103. }
  104. // Run in parallel.
  105. var taskResultSet, ok = Parallel(tasks...)
  106. assert.False(t, ok, "ok should be false since we panic'd in task #2.")
  107. // Verify task #0, #1, #2.
  108. checkResult(t, taskResultSet, 0, 0, nil, nil)
  109. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  110. checkResult(t, taskResultSet, 2, nil, nil, errors.Errorf("panic in task %v", 2).Error())
  111. }
  112. // Wait for result
  113. func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int,
  114. val interface{}, err error, pnk interface{}) {
  115. taskResult, ok := taskResultSet.LatestResult(index)
  116. taskName := fmt.Sprintf("Task #%v", index)
  117. assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
  118. assert.Equal(t, val, taskResult.Value, taskName)
  119. switch {
  120. case err != nil:
  121. assert.Equal(t, err.Error(), taskResult.Error.Error(), taskName)
  122. case pnk != nil:
  123. assert.Equal(t, pnk, taskResult.Error.Error(), taskName)
  124. default:
  125. assert.Nil(t, taskResult.Error, taskName)
  126. }
  127. }
  128. // Wait for timeout (no result)
  129. func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
  130. select {
  131. case _, ok := <-taskResultCh:
  132. if !ok {
  133. assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName)
  134. } else {
  135. assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName)
  136. }
  137. case <-time.After(1 * time.Second): // TODO use deterministic time?
  138. // Good!
  139. }
  140. }