package common import ( "errors" "fmt" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestParallel(t *testing.T) { // Create tasks. var counter = new(int32) var tasks = make([]Task, 100*1000) for i := 0; i < len(tasks); i++ { tasks[i] = func(i int) (res interface{}, err error, abort bool) { atomic.AddInt32(counter, 1) return -1 * i, nil, false } } // Run in parallel. var trs, ok = Parallel(tasks...) assert.True(t, ok) // Verify. assert.Equal(t, int(*counter), len(tasks), "Each task should have incremented the counter already") var failedTasks int for i := 0; i < len(tasks); i++ { taskResult, ok := trs.LatestResult(i) if !ok { assert.Fail(t, "Task #%v did not complete.", i) failedTasks++ } else if taskResult.Error != nil { assert.Fail(t, "Task should not have errored but got %v", taskResult.Error) failedTasks++ } else if !assert.Equal(t, -1*i, taskResult.Value.(int)) { assert.Fail(t, "Task should have returned %v but got %v", -1*i, taskResult.Value.(int)) failedTasks++ } // else { // Good! // } } assert.Equal(t, failedTasks, 0, "No task should have failed") assert.Nil(t, trs.FirstError(), "There should be no errors") assert.Equal(t, 0, trs.FirstValue(), "First value should be 0") } func TestParallelAbort(t *testing.T) { var flow1 = make(chan struct{}, 1) var flow2 = make(chan struct{}, 1) var flow3 = make(chan struct{}, 1) // Cap must be > 0 to prevent blocking. var flow4 = make(chan struct{}, 1) // Create tasks. var tasks = []Task{ func(i int) (res interface{}, err error, abort bool) { assert.Equal(t, i, 0) flow1 <- struct{}{} return 0, nil, false }, func(i int) (res interface{}, err error, abort bool) { assert.Equal(t, i, 1) flow2 <- <-flow1 return 1, errors.New("some error"), false }, func(i int) (res interface{}, err error, abort bool) { assert.Equal(t, i, 2) flow3 <- <-flow2 return 2, nil, true }, func(i int) (res interface{}, err error, abort bool) { assert.Equal(t, i, 3) <-flow4 return 3, nil, false }, } // Run in parallel. var taskResultSet, ok = Parallel(tasks...) assert.False(t, ok, "ok should be false since we aborted task #2.") // Verify task #3. // Initially taskResultSet.chz[3] sends nothing since flow4 didn't send. waitTimeout(t, taskResultSet.chz[3], "Task #3") // Now let the last task (#3) complete after abort. flow4 <- <-flow3 // Wait until all tasks have returned or panic'd. taskResultSet.Wait() // Verify task #0, #1, #2. checkResult(t, taskResultSet, 0, 0, nil, nil) checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) checkResult(t, taskResultSet, 2, 2, nil, nil) checkResult(t, taskResultSet, 3, 3, nil, nil) } func TestParallelRecover(t *testing.T) { // Create tasks. var tasks = []Task{ func(i int) (res interface{}, err error, abort bool) { return 0, nil, false }, func(i int) (res interface{}, err error, abort bool) { return 1, errors.New("some error"), false }, func(i int) (res interface{}, err error, abort bool) { panic(2) }, } // Run in parallel. var taskResultSet, ok = Parallel(tasks...) assert.False(t, ok, "ok should be false since we panic'd in task #2.") // Verify task #0, #1, #2. checkResult(t, taskResultSet, 0, 0, nil, nil) checkResult(t, taskResultSet, 1, 1, errors.New("some error"), nil) checkResult(t, taskResultSet, 2, nil, nil, 2) } // Wait for result func checkResult(t *testing.T, taskResultSet *TaskResultSet, index int, val interface{}, err error, pnk interface{}) { taskResult, ok := taskResultSet.LatestResult(index) taskName := fmt.Sprintf("Task #%v", index) assert.True(t, ok, "TaskResultCh unexpectedly closed for %v", taskName) assert.Equal(t, val, taskResult.Value, taskName) if err != nil { assert.Equal(t, err, taskResult.Error, taskName) } else if pnk != nil { assert.Equal(t, pnk, taskResult.Error.(Error).Data(), taskName) } else { assert.Nil(t, taskResult.Error, taskName) } } // Wait for timeout (no result) func waitTimeout(t *testing.T, taskResultCh TaskResultCh, taskName string) { select { case _, ok := <-taskResultCh: if !ok { assert.Fail(t, "TaskResultCh unexpectedly closed (%v)", taskName) } else { assert.Fail(t, "TaskResultCh unexpectedly returned for %v", taskName) } case <-time.After(1 * time.Second): // TODO use deterministic time? // Good! } }