From 97d47b52639280d87be8c3b492c8640c6b69d96f Mon Sep 17 00:00:00 2001 From: JayT106 Date: Fri, 4 Feb 2022 04:04:59 +0800 Subject: [PATCH] mempool: IDs issue fixes (#7763) --- internal/mempool/ids.go | 8 ++++ internal/mempool/ids_test.go | 69 +++++++++++++++++++++++++++++++- internal/mempool/reactor_test.go | 8 ++-- 3 files changed, 80 insertions(+), 5 deletions(-) diff --git a/internal/mempool/ids.go b/internal/mempool/ids.go index 3788afcbc..8b171e48a 100644 --- a/internal/mempool/ids.go +++ b/internal/mempool/ids.go @@ -30,6 +30,11 @@ func (ids *IDs) ReserveForPeer(peerID types.NodeID) { ids.mtx.Lock() defer ids.mtx.Unlock() + if _, ok := ids.peerMap[peerID]; ok { + // the peer has been reserved + return + } + curID := ids.nextPeerID() ids.peerMap[peerID] = curID ids.activeIDs[curID] = struct{}{} @@ -44,6 +49,9 @@ func (ids *IDs) Reclaim(peerID types.NodeID) { if ok { delete(ids.activeIDs, removedID) delete(ids.peerMap, peerID) + if removedID < ids.nextID { + ids.nextID = removedID + } } } diff --git a/internal/mempool/ids_test.go b/internal/mempool/ids_test.go index a39838627..006ad5ced 100644 --- a/internal/mempool/ids_test.go +++ b/internal/mempool/ids_test.go @@ -12,12 +12,77 @@ func TestMempoolIDsBasic(t *testing.T) { peerID, err := types.NewNodeID("0011223344556677889900112233445566778899") require.NoError(t, err) + require.EqualValues(t, 0, ids.GetForPeer(peerID)) ids.ReserveForPeer(peerID) require.EqualValues(t, 1, ids.GetForPeer(peerID)) + ids.Reclaim(peerID) + require.EqualValues(t, 0, ids.GetForPeer(peerID)) ids.ReserveForPeer(peerID) - require.EqualValues(t, 2, ids.GetForPeer(peerID)) - ids.Reclaim(peerID) + require.EqualValues(t, 1, ids.GetForPeer(peerID)) +} + +func TestMempoolIDsPeerDupReserve(t *testing.T) { + ids := NewMempoolIDs() + + peerID, err := types.NewNodeID("0011223344556677889900112233445566778899") + require.NoError(t, err) + require.EqualValues(t, 0, ids.GetForPeer(peerID)) + + ids.ReserveForPeer(peerID) + require.EqualValues(t, 1, ids.GetForPeer(peerID)) + + ids.ReserveForPeer(peerID) + require.EqualValues(t, 1, ids.GetForPeer(peerID)) +} + +func TestMempoolIDs2Peers(t *testing.T) { + ids := NewMempoolIDs() + + peer1ID, _ := types.NewNodeID("0011223344556677889900112233445566778899") + require.EqualValues(t, 0, ids.GetForPeer(peer1ID)) + + ids.ReserveForPeer(peer1ID) + require.EqualValues(t, 1, ids.GetForPeer(peer1ID)) + + ids.Reclaim(peer1ID) + require.EqualValues(t, 0, ids.GetForPeer(peer1ID)) + + peer2ID, _ := types.NewNodeID("1011223344556677889900112233445566778899") + + ids.ReserveForPeer(peer2ID) + require.EqualValues(t, 1, ids.GetForPeer(peer2ID)) + + ids.ReserveForPeer(peer1ID) + require.EqualValues(t, 2, ids.GetForPeer(peer1ID)) +} + +func TestMempoolIDsNextExistID(t *testing.T) { + ids := NewMempoolIDs() + + peer1ID, _ := types.NewNodeID("0011223344556677889900112233445566778899") + ids.ReserveForPeer(peer1ID) + require.EqualValues(t, 1, ids.GetForPeer(peer1ID)) + + peer2ID, _ := types.NewNodeID("1011223344556677889900112233445566778899") + ids.ReserveForPeer(peer2ID) + require.EqualValues(t, 2, ids.GetForPeer(peer2ID)) + + peer3ID, _ := types.NewNodeID("2011223344556677889900112233445566778899") + ids.ReserveForPeer(peer3ID) + require.EqualValues(t, 3, ids.GetForPeer(peer3ID)) + + ids.Reclaim(peer1ID) + require.EqualValues(t, 0, ids.GetForPeer(peer1ID)) + + ids.Reclaim(peer3ID) + require.EqualValues(t, 0, ids.GetForPeer(peer3ID)) + + ids.ReserveForPeer(peer1ID) + require.EqualValues(t, 1, ids.GetForPeer(peer1ID)) + + ids.ReserveForPeer(peer3ID) + require.EqualValues(t, 3, ids.GetForPeer(peer3ID)) } diff --git a/internal/mempool/reactor_test.go b/internal/mempool/reactor_test.go index d99b27edb..cddbc3be8 100644 --- a/internal/mempool/reactor_test.go +++ b/internal/mempool/reactor_test.go @@ -2,6 +2,7 @@ package mempool import ( "context" + "fmt" "os" "runtime" "strings" @@ -370,13 +371,14 @@ func TestMempoolIDsPanicsIfNodeRequestsOvermaxActiveIDs(t *testing.T) { // 0 is already reserved for UnknownPeerID ids := NewMempoolIDs() - peerID, err := types.NewNodeID("0011223344556677889900112233445566778899") - require.NoError(t, err) - for i := 0; i < MaxActiveIDs-1; i++ { + peerID, err := types.NewNodeID(fmt.Sprintf("%040d", i)) + require.NoError(t, err) ids.ReserveForPeer(peerID) } + peerID, err := types.NewNodeID(fmt.Sprintf("%040d", MaxActiveIDs-1)) + require.NoError(t, err) require.Panics(t, func() { ids.ReserveForPeer(peerID) })