Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions universalClient/tss/coordinator/coordinator.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (

"github.com/pushchain/push-chain-node/universalClient/chains"
"github.com/pushchain/push-chain-node/universalClient/chains/common"
"github.com/pushchain/push-chain-node/universalClient/pushcore"
"github.com/pushchain/push-chain-node/universalClient/store"
"github.com/pushchain/push-chain-node/universalClient/tss/eventstore"
"github.com/pushchain/push-chain-node/universalClient/tss/keyshare"
Expand All @@ -28,6 +27,15 @@ import (
"github.com/pushchain/push-chain-node/x/uvalidator/types"
)

// PushCoreClient is the subset of pushcore.Client the coordinator depends on.
// Defined as an interface so tests can inject a mock without spinning up a
// real Push Chain RPC endpoint. *pushcore.Client satisfies this interface.
type PushCoreClient interface {
	// GetLatestBlock returns the latest Push Chain block height.
	GetLatestBlock(ctx context.Context) (uint64, error)
	// GetCurrentKey returns the currently active TSS key.
	GetCurrentKey(ctx context.Context) (*utsstypes.TssKey, error)
	// GetAllUniversalValidators lists all registered universal validators.
	GetAllUniversalValidators(ctx context.Context) ([]*types.UniversalValidator, error)
}

const (
// PerChainCap is the max in-flight SIGN events per destination chain (default 16; below EVM mempool accountqueue 64).
PerChainCap = 16
Expand All @@ -47,7 +55,7 @@ type ackState struct {
type Coordinator struct {
// Dependencies
eventStore *eventstore.Store
pushCore *pushcore.Client
pushCore PushCoreClient
keyshareManager *keyshare.Manager
chains *chains.Chains

Expand Down Expand Up @@ -76,7 +84,7 @@ type Coordinator struct {
// NewCoordinator creates a new coordinator.
func NewCoordinator(
eventStore *eventstore.Store,
pushCore *pushcore.Client,
pushCore PushCoreClient,
keyshareManager *keyshare.Manager,
chains *chains.Chains,
validatorAddress string,
Expand Down
18 changes: 17 additions & 1 deletion universalClient/tss/networking/libp2p/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ import (
"github.com/pushchain/push-chain-node/universalClient/tss/networking"
)

// MaxFrameSize bounds a single length-prefixed frame on TSS streams. The cap
// rejects oversize length prefixes before allocation so a peer cannot trigger
// large attacker-chosen heap allocations. Sized well above the largest
// observed DKLS Step() + coordinator.Message wrapping for our committee sizes.
// Enforced symmetrically: writeFramed refuses to emit frames above the cap,
// and readFramed refuses to accept them.
const MaxFrameSize = 1 * 1024 * 1024 // 1 MiB

// Network implements networking.Network using libp2p.
type Network struct {
cfg Config
Expand Down Expand Up @@ -224,6 +230,9 @@ func loadIdentity(base64Key string) (crypto.PrivKey, error) {
}

func writeFramed(w io.Writer, data []byte) error {
if len(data) > MaxFrameSize {
return fmt.Errorf("frame size %d exceeds maximum %d", len(data), MaxFrameSize)
}
bw := bufio.NewWriter(w)
if err := binary.Write(bw, binary.BigEndian, uint32(len(data))); err != nil {
return err
Expand All @@ -235,11 +244,18 @@ func writeFramed(w io.Writer, data []byte) error {
}

func readFramed(r io.Reader) ([]byte, error) {
br := bufio.NewReader(r)
// Cap the underlying reader at MaxFrameSize+4 (4 bytes length prefix +
// payload) as defense-in-depth: even if the explicit length check below is
// ever bypassed by a future change, the reader cannot consume more than
// this many bytes from the peer.
br := bufio.NewReader(io.LimitReader(r, int64(MaxFrameSize)+4))
var length uint32
if err := binary.Read(br, binary.BigEndian, &length); err != nil {
return nil, err
}
if length > MaxFrameSize {
return nil, fmt.Errorf("frame size %d exceeds maximum %d", length, MaxFrameSize)
}
buf := make([]byte, length)
if _, err := io.ReadFull(br, buf); err != nil {
return nil, err
Expand Down
80 changes: 80 additions & 0 deletions universalClient/tss/networking/libp2p/network_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package libp2p

import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestReadFramed_RoundTrip(t *testing.T) {
	// A payload framed by writeFramed must come back unchanged via readFramed.
	want := []byte("hello tss")

	var wire bytes.Buffer
	require.NoError(t, writeFramed(&wire, want))

	got, err := readFramed(&wire)
	require.NoError(t, err)
	assert.Equal(t, want, got)
}

func TestReadFramed_RejectsOversizeLengthPrefix(t *testing.T) {
	// A length prefix claiming MaxFrameSize+1 bytes must be rejected up front,
	// before readFramed allocates a buffer for the (absent) body. The buffer
	// deliberately carries no payload bytes after the prefix.
	var wire bytes.Buffer
	require.NoError(t, binary.Write(&wire, binary.BigEndian, uint32(MaxFrameSize+1)))

	_, err := readFramed(&wire)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "exceeds maximum")
}

func TestReadFramed_AcceptsAtMaxFrameSize(t *testing.T) {
	// Boundary: a length prefix of exactly MaxFrameSize must pass the size
	// check. We don't put a full 1 MiB payload in the test buffer; the buffer
	// ends right after the length prefix, so a passing size check surfaces as
	// an EOF-style error from the body read rather than the size-cap
	// rejection.
	var buf bytes.Buffer
	require.NoError(t, binary.Write(&buf, binary.BigEndian, uint32(MaxFrameSize)))

	_, err := readFramed(&buf)
	require.Error(t, err)
	// Should be EOF/UnexpectedEOF on the body read, NOT the size-cap rejection.
	assert.NotContains(t, err.Error(), "exceeds maximum")
	// errors.Is (rather than ==) keeps the test valid if readFramed ever
	// wraps the read error the way it already wraps the size-cap error.
	assert.True(t, errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF),
		"expected EOF on truncated body, got: %v", err)
}

func TestWriteFramed_RejectsOversizePayload(t *testing.T) {
	// The cap is enforced on the send side too, so a misbehaving local sender
	// can never produce a frame the receiving peer would itself reject —
	// avoiding silent protocol drops where the wire format crosses the line.
	tooBig := make([]byte, MaxFrameSize+1)

	var wire bytes.Buffer
	err := writeFramed(&wire, tooBig)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "exceeds maximum")

	// Nothing may hit the wire: the size check precedes any write.
	assert.Zero(t, wire.Len(), "writeFramed must not emit any bytes when rejecting")
}

func TestWriteFramed_AcceptsAtMaxFrameSize(t *testing.T) {
	// Boundary: a payload of exactly MaxFrameSize must survive a full
	// write/read round trip.
	payload := make([]byte, MaxFrameSize)
	for i := range payload {
		payload[i] = byte(i) // byte conversion truncates, same as i % 256
	}

	var wire bytes.Buffer
	require.NoError(t, writeFramed(&wire, payload))

	got, err := readFramed(&wire)
	require.NoError(t, err)
	assert.Equal(t, len(payload), len(got))
	assert.Equal(t, payload[0], got[0])
	assert.Equal(t, payload[len(payload)-1], got[len(got)-1])
}

32 changes: 16 additions & 16 deletions universalClient/tss/sessionmanager/sessionmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ type SendFunc func(ctx context.Context, peerID string, data []byte) error
// sessionState holds all state for a single session.
type sessionState struct {
session dkls.Session
protocolType string // type of protocol (keygen, keyrefresh, quorumchange, sign)
coordinator string // coordinatorPeerID
expiryTime time.Time // when session expires
participants []string // list of participants (from setup message)
stepMu sync.Mutex // mutex to serialize Step() calls (DKLS may not be thread-safe)
protocolType string // type of protocol (keygen, keyrefresh, quorumchange, sign)
coordinator string // coordinatorPeerID
expiryTime time.Time // when session expires
participants []string // list of participants (from setup message)
stepMu sync.Mutex // mutex to serialize Step() calls (DKLS may not be thread-safe)
signingReq *common.UnsignedSigningReq // cached from coordinator setup (sign sessions only)
}

Expand Down Expand Up @@ -139,13 +139,22 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s
return nil
}

// 2. Validate event exists in DB
// 2. Validate sender is coordinator
isCoord, err := sm.coordinator.IsPeerCoordinator(ctx, senderPeerID)
if err != nil {
return fmt.Errorf("failed to check if sender is coordinator: %w", err)
}
if !isCoord {
return fmt.Errorf("sender %s is not the coordinator", senderPeerID)
}

// 3. Validate event exists in DB
event, err := sm.eventStore.GetEvent(msg.EventID)
if err != nil {
return fmt.Errorf("event %s not found in database: %w", msg.EventID, err)
}

// 3. Validate event is CONFIRMED and not expired
// 4. Validate event is CONFIRMED and not expired
if event.Status != store.StatusConfirmed {
return fmt.Errorf("event %s is not in confirmed status (got %s)", msg.EventID, event.Status)
}
Expand All @@ -157,15 +166,6 @@ func (sm *SessionManager) handleSetupMessage(ctx context.Context, senderPeerID s
return fmt.Errorf("event %s has expired (expiry_block_height %d <= current_block %d)", msg.EventID, event.ExpiryBlockHeight, currentBlock)
}

// 4. Validate sender is coordinator
isCoord, err := sm.coordinator.IsPeerCoordinator(ctx, senderPeerID)
if err != nil {
return fmt.Errorf("failed to check if sender is coordinator: %w", err)
}
if !isCoord {
return fmt.Errorf("sender %s is not the coordinator", senderPeerID)
}

// 5. Validate participants list matches event protocol requirements
if err := sm.validateParticipants(msg.Participants, event); err != nil {
return fmt.Errorf("participants validation failed: %w", err)
Expand Down
Loading
Loading