You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

157 lines
4.3 KiB

  1. package common
  2. import (
  3. "errors"
  4. "fmt"
  5. "sync/atomic"
  6. "testing"
  7. "time"
  8. "github.com/stretchr/testify/assert"
  9. )
  10. func TestParallel(t *testing.T) {
  11. // Create tasks.
  12. var counter = new(int32)
  13. var tasks = make([]Task, 100*1000)
  14. for i := 0; i < len(tasks); i++ {
  15. tasks[i] = func(i int) (res interface{}, err error, abort bool) {
  16. atomic.AddInt32(counter, 1)
  17. return -1 * i, nil, false
  18. }
  19. }
  20. // Run in parallel.
  21. var trs, ok = Parallel(tasks...)
  22. assert.True(t, ok)
  23. // Verify.
  24. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already")
  25. var failedTasks int
  26. for i := 0; i < len(tasks); i++ {
  27. taskResult, ok := trs.LatestResult(i)
  28. if !ok {
  29. assert.Fail(t, "Task #%v did not complete.", i)
  30. failedTasks++
  31. } else if taskResult.Error != nil {
  32. assert.Fail(t, "Task should not have errored but got %v", taskResult.Error)
  33. failedTasks++
  34. } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) {
  35. assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int))
  36. failedTasks++
  37. }
  38. // else {
  39. // Good!
  40. // }
  41. }
  42. assert.Equal(t, failedTasks, 0, "No task should have failed")
  43. assert.Nil(t, trs.FirstError(), "There should be no errors")
  44. assert.Equal(t, 0, trs.FirstValue(), "First value should be 0")
  45. }
  46. func TestParallelAbort(t *testing.T) {
  47. var flow1 = make(chan struct{}, 1)
  48. var flow2 = make(chan struct{}, 1)
  49. var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking.
  50. var flow4 = make(chan struct{}, 1)
  51. // Create tasks.
  52. var tasks = []Task{
  53. func(i int) (res interface{}, err error, abort bool) {
  54. assert.Equal(t, i, 0)
  55. flow1 <- struct{}{}
  56. return 0, nil, false
  57. },
  58. func(i int) (res interface{}, err error, abort bool) {
  59. assert.Equal(t, i, 1)
  60. flow2 <- <-flow1
  61. return 1, errors.New("some error"), false
  62. },
  63. func(i int) (res interface{}, err error, abort bool) {
  64. assert.Equal(t, i, 2)
  65. flow3 <- <-flow2
  66. return 2, nil, true
  67. },
  68. func(i int) (res interface{}, err error, abort bool) {
  69. assert.Equal(t, i, 3)
  70. <-flow4
  71. return 3, nil, false
  72. },
  73. }
  74. // Run in parallel.
  75. var taskResultSet, ok = Parallel(tasks...)
  76. assert.False(t, ok, "ok should be false since we aborted task #2.")
  77. // Verify task #3.
  78. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send.
  79. waitTimeout(t, taskResultSet.chz[3], "Task #3")
  80. // Now let the last task (#3) complete after abort.
  81. flow4 <- <-flow3
  82. // Wait until all tasks have returned or panic'd.
  83. taskResultSet.Wait()
  84. // Verify task #0, #1, #2.
  85. checkResult(t, taskResultSet, 0, 0, nil, nil)
  86. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  87. checkResult(t, taskResultSet, 2, 2, nil, nil)
  88. checkResult(t, taskResultSet, 3, 3, nil, nil)
  89. }
  90. func TestParallelRecover(t *testing.T) {
  91. // Create tasks.
  92. var tasks = []Task{
  93. func(i int) (res interface{}, err error, abort bool) {
  94. return 0, nil, false
  95. },
  96. func(i int) (res interface{}, err error, abort bool) {
  97. return 1, errors.New("some error"), false
  98. },
  99. func(i int) (res interface{}, err error, abort bool) {
  100. panic(2)
  101. },
  102. }
  103. // Run in parallel.
  104. var taskResultSet, ok = Parallel(tasks...)
  105. assert.False(t, ok, "ok should be false since we panic'd in task #2.")
  106. // Verify task #0, #1, #2.
  107. checkResult(t, taskResultSet, 0, 0, nil, nil)
  108. checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil)
  109. checkResult(t, taskResultSet, 2, nil, nil, 2)
  110. }
  111. // Wait for result
  112. func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) {
  113. taskResult, ok := taskResultSet.LatestResult(index)
  114. taskName := fmt.Sprintf("Task #%v", index)
  115. assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName)
  116. assert.Equal(t, val, taskResult.Value, taskName)
  117. if err != nil {
  118. assert.Equal(t, err, taskResult.Error, taskName)
  119. } else if pnk != nil {
  120. assert.Equal(t, pnk, taskResult.Error.(Error).Data(), taskName)
  121. } else {
  122. assert.Nil(t, taskResult.Error, taskName)
  123. }
  124. }
  125. // Wait for timeout (no result)
  126. func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) {
  127. select {
  128. case _, ok := <-taskResultCh:
  129. if !ok {
  130. assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName)
  131. } else {
  132. assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName)
  133. }
  134. case <-time.After(1 * time.Second): // TODO use deterministic time?
  135. // Good!
  136. }
  137. }