diff --git a/connection_test.go b/connection_test.go index 9d489f655..79a7227a2 100644 --- a/connection_test.go +++ b/connection_test.go @@ -103,3 +103,36 @@ func TestMConnectionStatus(t *testing.T) { assert.NotNil(status) assert.Zero(status.Channels[0].SendQueueSize) } + +func TestMConnectionNonPersistent(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + receivedCh := make(chan []byte) + errorsCh := make(chan interface{}) + onReceive := func(chID byte, msgBytes []byte) { + receivedCh <- msgBytes + } + onError := func(r interface{}) { + errorsCh <- r + } + mconn := createMConnectionWithCallbacks(client, onReceive, onError) + _, err := mconn.Start() + require.Nil(err) + defer mconn.Stop() + + client.Close() + + select { + case receivedBytes := <-receivedCh: + t.Fatalf("Expected error, got %v", receivedBytes) + case err := <-errorsCh: + assert.NotNil(err) + assert.False(mconn.IsRunning()) + case <-time.After(500 * time.Millisecond): + t.Fatal("Did not receive error in 500ms") + } +}