diff --git a/types/part_set_test.go b/types/part_set_test.go index 01437f05e..3576e747e 100644 --- a/types/part_set_test.go +++ b/types/part_set_test.go @@ -1,10 +1,12 @@ package types import ( - "bytes" "io/ioutil" "testing" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + cmn "github.com/tendermint/tendermint/libs/common" ) @@ -13,24 +15,21 @@ const ( ) func TestBasicPartSet(t *testing.T) { - // Construct random data of size partSize * 100 data := cmn.RandBytes(testPartSize * 100) - partSet := NewPartSetFromData(data, testPartSize) - if len(partSet.Hash()) == 0 { - t.Error("Expected to get hash") - } - if partSet.Total() != 100 { - t.Errorf("Expected to get 100 parts, but got %v", partSet.Total()) - } - if !partSet.IsComplete() { - t.Errorf("PartSet should be complete") - } + + assert.NotEmpty(t, partSet.Hash()) + assert.Equal(t, 100, partSet.Total()) + assert.Equal(t, 100, partSet.BitArray().Size()) + assert.True(t, partSet.HashesTo(partSet.Hash())) + assert.True(t, partSet.IsComplete()) + assert.Equal(t, 100, partSet.Count()) // Test adding parts to a new partSet. partSet2 := NewPartSetFromHeader(partSet.Header()) + assert.True(t, partSet2.HasHeader(partSet.Header())) for i := 0; i < partSet.Total(); i++ { part := partSet.GetPart(i) //t.Logf("\n%v", part) @@ -39,31 +38,28 @@ func TestBasicPartSet(t *testing.T) { t.Errorf("Failed to add part %v, error: %v", i, err) } } - - if !bytes.Equal(partSet.Hash(), partSet2.Hash()) { - t.Error("Expected to get same hash") - } - if partSet2.Total() != 100 { - t.Errorf("Expected to get 100 parts, but got %v", partSet2.Total()) - } - if !partSet2.IsComplete() { - t.Errorf("Reconstructed PartSet should be complete") - } + // adding part with invalid index + added, err := partSet2.AddPart(&Part{Index: 10000}) + assert.False(t, added) + assert.Error(t, err) + // adding existing part + added, err = partSet2.AddPart(partSet2.GetPart(0)) + assert.False(t, added) + assert.Nil(t, err) + + assert.Equal(t, partSet.Hash(), partSet2.Hash()) + assert.Equal(t, 100, partSet2.Total()) + assert.True(t, partSet2.IsComplete()) // Reconstruct data, assert that they are equal. data2Reader := partSet2.GetReader() data2, err := ioutil.ReadAll(data2Reader) - if err != nil { - t.Errorf("Error reading data2Reader: %v", err) - } - if !bytes.Equal(data, data2) { - t.Errorf("Got wrong data.") - } + require.NoError(t, err) + assert.Equal(t, data, data2) } func TestWrongProof(t *testing.T) { - // Construct random data of size partSize * 100 data := cmn.RandBytes(testPartSize * 100) partSet := NewPartSetFromData(data, testPartSize) @@ -86,5 +82,4 @@ func TestWrongProof(t *testing.T) { if added || err == nil { t.Errorf("Expected to fail adding a part with bad bytes.") } - }