From fd9e54a081a26e1357aeced3a65625335badecb7 Mon Sep 17 00:00:00 2001
From: aman035
Date: Fri, 8 May 2026 17:30:57 +0530
Subject: [PATCH 1/3] add: MaxFrameSize to p2p network

---
 .../tss/networking/libp2p/network.go      | 18 ++++-
 .../tss/networking/libp2p/network_test.go | 80 +++++++++++++++++++
 2 files changed, 97 insertions(+), 1 deletion(-)
 create mode 100644 universalClient/tss/networking/libp2p/network_test.go

diff --git a/universalClient/tss/networking/libp2p/network.go b/universalClient/tss/networking/libp2p/network.go
index 1d3f1d6f..8710940c 100644
--- a/universalClient/tss/networking/libp2p/network.go
+++ b/universalClient/tss/networking/libp2p/network.go
@@ -25,6 +25,12 @@ import (
 	"github.com/pushchain/push-chain-node/universalClient/tss/networking"
 )
 
+// MaxFrameSize bounds a single length-prefixed frame on TSS streams. The cap
+// rejects oversize length prefixes before allocation so a peer cannot trigger
+// large attacker-chosen heap allocations. Sized well above the largest
+// observed DKLS Step() + coordinator.Message wrapping for our committee sizes.
+const MaxFrameSize = 1 * 1024 * 1024 // 1 MiB
+
 // Network implements networking.Network using libp2p.
 type Network struct {
 	cfg Config
@@ -224,6 +230,9 @@ func loadIdentity(base64Key string) (crypto.PrivKey, error) {
 }
 
 func writeFramed(w io.Writer, data []byte) error {
+	if len(data) > MaxFrameSize {
+		return fmt.Errorf("frame size %d exceeds maximum %d", len(data), MaxFrameSize)
+	}
 	bw := bufio.NewWriter(w)
 	if err := binary.Write(bw, binary.BigEndian, uint32(len(data))); err != nil {
 		return err
@@ -235,11 +244,18 @@ func writeFramed(w io.Writer, data []byte) error {
 }
 
 func readFramed(r io.Reader) ([]byte, error) {
-	br := bufio.NewReader(r)
+	// Cap the underlying reader at MaxFrameSize+4 (4 bytes length prefix +
+	// payload) as defense-in-depth: even if the explicit length check below is
+	// ever bypassed by a future change, the reader cannot consume more than
+	// this many bytes from the peer.
+	br := bufio.NewReader(io.LimitReader(r, int64(MaxFrameSize)+4))
 	var length uint32
 	if err := binary.Read(br, binary.BigEndian, &length); err != nil {
 		return nil, err
 	}
+	if length > MaxFrameSize {
+		return nil, fmt.Errorf("frame size %d exceeds maximum %d", length, MaxFrameSize)
+	}
 	buf := make([]byte, length)
 	if _, err := io.ReadFull(br, buf); err != nil {
 		return nil, err
diff --git a/universalClient/tss/networking/libp2p/network_test.go b/universalClient/tss/networking/libp2p/network_test.go
new file mode 100644
index 00000000..8c984668
--- /dev/null
+++ b/universalClient/tss/networking/libp2p/network_test.go
@@ -0,0 +1,80 @@
+package libp2p
+
+import (
+	"bytes"
+	"encoding/binary"
+	"io"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestReadFramed_RoundTrip(t *testing.T) {
+	payload := []byte("hello tss")
+	var buf bytes.Buffer
+	require.NoError(t, writeFramed(&buf, payload))
+
+	got, err := readFramed(&buf)
+	require.NoError(t, err)
+	assert.Equal(t, payload, got)
+}
+
+func TestReadFramed_RejectsOversizeLengthPrefix(t *testing.T) {
+	// Craft a frame whose length prefix claims more than MaxFrameSize.
+	// readFramed must reject before allocating MaxFrameSize+1 bytes.
+	var buf bytes.Buffer
+	require.NoError(t, binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize+1)))
+	// No payload bytes follow — readFramed should fail on the length check
+	// before attempting to read the (non-existent) body.
+
+	_, err := readFramed(&buf)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "exceeds maximum")
+}
+
+func TestReadFramed_AcceptsAtMaxFrameSize(t *testing.T) {
+	// Boundary: a length prefix of exactly MaxFrameSize must pass the size check.
+	// We don't actually allocate the 1 MiB payload in the test buffer; instead we
+	// validate the length-check path with a reader that returns EOF after
+	// the length prefix and assert the failure mode is the read error,
+	// not the size-cap error.
+	var buf bytes.Buffer
+	require.NoError(t, binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize)))
+
+	_, err := readFramed(&buf)
+	require.Error(t, err)
+	// Should be EOF/UnexpectedEOF on the body read, NOT the size-cap rejection.
+	assert.NotContains(t, err.Error(), "exceeds maximum")
+	assert.True(t, err == io.EOF || err == io.ErrUnexpectedEOF, "expected EOF on truncated body, got: %v", err)
+}
+
+func TestWriteFramed_RejectsOversizePayload(t *testing.T) {
+	// writeFramed must apply the same cap so a misbehaving local sender cannot
+	// produce a frame that the receiving peer would itself reject. This avoids
+	// silent protocol drops when a frame crosses the size limit.
+	oversize := make([]byte, MaxFrameSize+1)
+	var buf bytes.Buffer
+	err := writeFramed(&buf, oversize)
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "exceeds maximum")
+	// Buffer must not contain partial data — the check happens before any write.
+	assert.Equal(t, 0, buf.Len(), "writeFramed must not emit any bytes when rejecting")
+}
+
+func TestWriteFramed_AcceptsAtMaxFrameSize(t *testing.T) {
+	// Boundary: payload of exactly MaxFrameSize must round-trip.
+	payload := make([]byte, MaxFrameSize)
+	for i := range payload {
+		payload[i] = byte(i % 256)
+	}
+	var buf bytes.Buffer
+	require.NoError(t, writeFramed(&buf, payload))
+
+	got, err := readFramed(&buf)
+	require.NoError(t, err)
+	assert.Equal(t, len(payload), len(got))
+	assert.Equal(t, payload[0], got[0])
+	assert.Equal(t, payload[len(payload)-1], got[len(got)-1])
+}
+

From 82aab9c69f5a05688f0368814873db78a5d859c0 Mon Sep 17 00:00:00 2001
From: aman035
Date: Fri, 8 May 2026 17:31:54 +0530
Subject: [PATCH 2/3] refactor: move coordinator check up so malicious peer requests are rejected sooner

---
 .../tss/sessionmanager/sessionmanager.go | 32 +++++++++----------
 1 file changed, 16 insertions(+), 16 deletions(-)

diff --git a/universalClient/tss/sessionmanager/sessionmanager.go b/universalClient/tss/sessionmanager/sessionmanager.go
index 46c8aa36..03123916 100644
--- a/universalClient/tss/sessionmanager/sessionmanager.go
+++ b/universalClient/tss/sessionmanager/sessionmanager.go
@@ -33,11 +33,11 @@ type SendFunc func(ctx context.Context, peerID string, data []byte) error
 
 // sessionState holds all state for a single session.
 type sessionState struct {
 	session      dkls.Session
-	protocolType string    // type of protocol (keygen, keyrefresh, quorumchange, sign)
-	coordinator  string    // coordinatorPeerID
-	expiryTime   time.Time // when session expires
-	participants []string  // list of participants (from setup message)
-	stepMu       sync.Mutex // mutex to serialize Step() calls (DKLS may not be thread-safe)
+	protocolType string     // type of protocol (keygen, keyrefresh, quorumchange, sign)
+	coordinator  string     // coordinatorPeerID
+	expiryTime   time.Time  // when session expires
+	participants []string   // list of participants (from setup message)
+	stepMu       sync.Mutex // mutex to serialize Step() calls (DKLS may not be thread-safe)
 	signingReq   *common.UnsignedSigningReq // cached from coordinator setup (sign sessions only)
 }
@@ -139,13 +139,22 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s
 		return nil
 	}
 
-	// 2. Validate event exists in DB
+	// 2. Validate sender is coordinator
+	isCoord, err := sm.coordinator.IsPeerCoordinator(ctx, senderPeerID)
+	if err != nil {
+		return fmt.Errorf("failed to check if sender is coordinator: %w", err)
+	}
+	if !isCoord {
+		return fmt.Errorf("sender %s is not the coordinator", senderPeerID)
+	}
+
+	// 3. Validate event exists in DB
 	event, err := sm.eventStore.GetEvent(msg.EventID)
 	if err != nil {
 		return fmt.Errorf("event %s not found in database: %w", msg.EventID, err)
 	}
 
-	// 3. Validate event is CONFIRMED and not expired
+	// 4. Validate event is CONFIRMED and not expired
 	if event.Status != store.StatusConfirmed {
 		return fmt.Errorf("event %s is not in confirmed status (got %s)", msg.EventID, event.Status)
 	}
@@ -157,15 +166,6 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s
 		return fmt.Errorf("event %s has expired (expiry_block_height %d <= current_block %d)", msg.EventID, event.ExpiryBlockHeight, currentBlock)
 	}
 
-	// 4. Validate sender is coordinator
-	isCoord, err := sm.coordinator.IsPeerCoordinator(ctx, senderPeerID)
-	if err != nil {
-		return fmt.Errorf("failed to check if sender is coordinator: %w", err)
-	}
-	if !isCoord {
-		return fmt.Errorf("sender %s is not the coordinator", senderPeerID)
-	}
-
 	// 5. Validate participants list matches event protocol requirements
 	if err := sm.validateParticipants(msg.Participants, event); err != nil {
 		return fmt.Errorf("participants validation failed: %w", err)

From 6697ba08e9efa777f8615d525546e63578a106aa Mon Sep 17 00:00:00 2001
From: aman035
Date: Fri, 8 May 2026 17:32:11 +0530
Subject: [PATCH 3/3] chore: fix test cases

---
 .../tss/coordinator/coordinator.go            |  14 ++-
 .../tss/sessionmanager/sessionmanager_test.go | 115 +++++++++++++++---
 2 files changed, 109 insertions(+), 20 deletions(-)

diff --git a/universalClient/tss/coordinator/coordinator.go b/universalClient/tss/coordinator/coordinator.go
index 14aed8d5..94e030d7 100644
--- a/universalClient/tss/coordinator/coordinator.go
+++ b/universalClient/tss/coordinator/coordinator.go
@@ -19,7 +19,6 @@ import (
 	"github.com/pushchain/push-chain-node/universalClient/chains"
 	"github.com/pushchain/push-chain-node/universalClient/chains/common"
-	"github.com/pushchain/push-chain-node/universalClient/pushcore"
 	"github.com/pushchain/push-chain-node/universalClient/store"
 	"github.com/pushchain/push-chain-node/universalClient/tss/eventstore"
 	"github.com/pushchain/push-chain-node/universalClient/tss/keyshare"
@@ -28,6 +27,15 @@ import (
 	"github.com/pushchain/push-chain-node/x/uvalidator/types"
 )
 
+// PushCoreClient is the subset of pushcore.Client the coordinator depends on.
+// Defined as an interface so tests can inject a mock without spinning up a
+// real Push Chain RPC endpoint. *pushcore.Client satisfies this interface.
+type PushCoreClient interface {
+	GetLatestBlock(ctx context.Context) (uint64, error)
+	GetCurrentKey(ctx context.Context) (*utsstypes.TssKey, error)
+	GetAllUniversalValidators(ctx context.Context) ([]*types.UniversalValidator, error)
+}
+
 const (
 	// PerChainCap is the max in-flight SIGN events per destination chain (default 16; below EVM mempool accountqueue 64).
 	PerChainCap = 16
@@ -47,7 +55,7 @@ type ackState struct {
 type Coordinator struct {
 	// Dependencies
 	eventStore      *eventstore.Store
-	pushCore        *pushcore.Client
+	pushCore        PushCoreClient
 	keyshareManager *keyshare.Manager
 	chains          *chains.Chains
 
@@ -76,7 +84,7 @@ type Coordinator struct {
 // NewCoordinator creates a new coordinator.
 func NewCoordinator(
 	eventStore *eventstore.Store,
-	pushCore *pushcore.Client,
+	pushCore PushCoreClient,
 	keyshareManager *keyshare.Manager,
 	chains *chains.Chains,
 	validatorAddress string,
diff --git a/universalClient/tss/sessionmanager/sessionmanager_test.go b/universalClient/tss/sessionmanager/sessionmanager_test.go
index cdd5b1b3..285e0586 100644
--- a/universalClient/tss/sessionmanager/sessionmanager_test.go
+++ b/universalClient/tss/sessionmanager/sessionmanager_test.go
@@ -19,7 +19,6 @@ import (
 	"github.com/pushchain/push-chain-node/universalClient/chains"
 	"github.com/pushchain/push-chain-node/universalClient/chains/common"
 	"github.com/pushchain/push-chain-node/universalClient/config"
-	"github.com/pushchain/push-chain-node/universalClient/pushcore"
 	"github.com/pushchain/push-chain-node/universalClient/store"
 	"github.com/pushchain/push-chain-node/universalClient/tss/coordinator"
 	"github.com/pushchain/push-chain-node/universalClient/tss/dkls"
@@ -42,6 +41,25 @@ func containsAny(s string, substrings []string) bool {
 	return false
 }
 
+// mockPushCore is a stub PushCoreClient for tests so the coordinator's
+// IsPeerCoordinator path doesn't need a live Push Chain RPC. Returns a fixed
+// block height (0 by default) so coordinator-at-block math is deterministic.
+type mockPushCore struct {
+	block uint64
+}
+
+func (m *mockPushCore) GetLatestBlock(_ context.Context) (uint64, error) {
+	return m.block, nil
+}
+
+func (m *mockPushCore) GetCurrentKey(_ context.Context) (*utsstypes.TssKey, error) {
+	return &utsstypes.TssKey{KeyId: "test-key"}, nil
+}
+
+func (m *mockPushCore) GetAllUniversalValidators(_ context.Context) ([]*types.UniversalValidator, error) {
+	return nil, nil
+}
+
 // mockSession is a mock implementation of dkls.Session for testing.
 type mockSession struct {
 	mock.Mock
@@ -73,7 +91,7 @@ func (m *mockSession) Close() {
 }
 
 // setupTestSessionManager creates a test session manager with real coordinator and test dependencies.
-func setupTestSessionManager(t *testing.T) (*SessionManager, *coordinator.Coordinator, *eventstore.Store, *keyshare.Manager, *pushcore.Client, *gorm.DB) {
+func setupTestSessionManager(t *testing.T) (*SessionManager, *coordinator.Coordinator, *eventstore.Store, *keyshare.Manager, *mockPushCore, *gorm.DB) {
 	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
 	require.NoError(t, err)
 	require.NoError(t, db.AutoMigrate(&store.Event{}))
@@ -82,8 +100,8 @@ func setupTestSessionManager(t *testing.T) (*SessionManager, *coordinator.Coordi
 	keyshareMgr, err := keyshare.NewManager(t.TempDir(), "test-password")
 	require.NoError(t, err)
-	// Create a minimal client (will fail on actual calls, but that's OK for most tests)
-	testClient := &pushcore.Client{}
+	// Inject a stub PushCoreClient so coordinator RPC paths return canned data.
+	testClient := &mockPushCore{block: 0}
 
 	sendFn := func(ctx context.Context, peerID string, data []byte) error {
 		return nil
 	}
@@ -193,6 +211,8 @@ func TestHandleSetupMessage_Validation(t *testing.T) {
 	require.NoError(t, testDB.Create(&event).Error)
 
 	t.Run("event not found", func(t *testing.T) {
+		// peer1 is the coordinator at block 0 (validator1, slot 0), so the
+		// sender check passes and we reach the DB lookup, which fails.
 		msg := coordinator.Message{
 			Type:    "setup",
 			EventID: "nonexistent",
@@ -204,25 +224,18 @@ func TestHandleSetupMessage_Validation(t *testing.T) {
 	})
 
 	t.Run("sender not coordinator", func(t *testing.T) {
-		// peer2 is not the coordinator at block 0 (epoch 0, index 0 = validator1/peer1)
-		// So sending from peer2 should fail coordinator check
+		// peer2 is not the coordinator at block 0 (validator1/peer1 is).
 		msg := coordinator.Message{
 			Type:    "setup",
 			EventID: event.EventID,
 		}
 		data, _ := json.Marshal(msg)
-		err := sm.HandleIncomingMessage(ctx, "peer2", data) // Send from peer2
-		// This will fail because GetLatestBlockNum needs real client
-		// But the error should indicate coordinator check failed
+		err := sm.HandleIncomingMessage(ctx, "peer2", data)
 		assert.Error(t, err)
-		// Error will be about no endpoints, but that's expected
-		assert.Contains(t, err.Error(), "no endpoints")
+		assert.Contains(t, err.Error(), "is not the coordinator")
 	})
 
 	t.Run("invalid participants", func(t *testing.T) {
-		// Note: This test will also fail on GetLatestBlockNum, but we can test
-		// the participants validation logic by ensuring the coordinator check passes
-		// For now, we'll accept that GetLatestBlockNum will fail
 		msg := coordinator.Message{
 			Type:    "setup",
 			EventID: event.EventID,
@@ -230,13 +243,81 @@ func TestHandleSetupMessage_Validation(t *testing.T) {
 		}
 		data, _ := json.Marshal(msg)
 		err := sm.HandleIncomingMessage(ctx, "peer1", data)
-		// Will fail on GetLatestBlockNum, but that's expected
 		assert.Error(t, err)
-		// Error should be about no endpoints (from GetLatestBlockNum)
-		assert.Contains(t, err.Error(), "no endpoints")
+		assert.Contains(t, err.Error(), "participants validation failed")
 	})
+
+	t.Run("non-coordinator sender for non-existent event hits coord check first", func(t *testing.T) {
+		// Locks in the ordering invariant: IsPeerCoordinator runs before the
+		// DB lookup, so a bogus SETUP from a non-coordinator peer is rejected
+		// without touching the event store even when the event id is unknown.
+		msg := coordinator.Message{
+			Type:    "setup",
+			EventID: "nonexistent",
+		}
+		data, _ := json.Marshal(msg)
+		err := sm.HandleIncomingMessage(ctx, "peer2", data)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "is not the coordinator")
+		assert.NotContains(t, err.Error(), "not found in database")
+	})
 }
 
+func TestHandleSetupMessage_Expiry(t *testing.T) {
+	sm, _, _, _, _, testDB := setupTestSessionManager(t)
+	ctx := context.Background()
+
+	t.Run("event with ExpiryBlockHeight <= current block is rejected", func(t *testing.T) {
+		past := store.Event{
+			EventID:           "past-event",
+			BlockHeight:       1,
+			Type:              "KEYGEN",
+			Status:            store.StatusConfirmed,
+			ExpiryBlockHeight: 1,
+		}
+		require.NoError(t, testDB.Create(&past).Error)
+
+		// Bump the coordinator's mock to block 5 so 1 <= 5 fires the guard.
+		setCoordinatorPushCore(sm.coordinator, &mockPushCore{block: 5})
+
+		msg := coordinator.Message{Type: "setup", EventID: past.EventID}
+		data, _ := json.Marshal(msg)
+		err := sm.HandleIncomingMessage(ctx, "peer1", data)
+		assert.Error(t, err)
+		assert.Contains(t, err.Error(), "has expired")
+	})
+
+	t.Run("event with ExpiryBlockHeight 0 is treated as no-expiry", func(t *testing.T) {
+		event := store.Event{
+			EventID:           "no-expiry-event",
+			BlockHeight:       1,
+			Type:              "KEYGEN",
+			Status:            store.StatusConfirmed,
+			ExpiryBlockHeight: 0,
+		}
+		require.NoError(t, testDB.Create(&event).Error)
+
+		setCoordinatorPushCore(sm.coordinator, &mockPushCore{block: 0})
+		msg := coordinator.Message{Type: "setup", EventID: event.EventID}
+		data, _ := json.Marshal(msg)
+		err := sm.HandleIncomingMessage(ctx, "peer1", data)
+		// A later check (participants) fails, but the expiry branch must not fire.
+		assert.Error(t, err)
+		assert.NotContains(t, err.Error(), "has expired")
+	})
+}
+
+// setCoordinatorPushCore swaps the coordinator's pushCore field via reflect+unsafe
+// so individual tests can override the mock per-case.
+func setCoordinatorPushCore(coord *coordinator.Coordinator, client coordinator.PushCoreClient) {
+	coordValue := reflect.ValueOf(coord).Elem()
+	field := coordValue.FieldByName("pushCore")
+	if !field.IsValid() {
+		return
+	}
+	*(*coordinator.PushCoreClient)(unsafe.Pointer(field.UnsafeAddr())) = client
+}
+
 func TestHandleStepMessage_Validation(t *testing.T) {
 	sm, _, _, _, _, _ := setupTestSessionManager(t)
 	ctx := context.Background()