From 98dc840e2f677ba5a835aa987e434fa9dbff5e84 Mon Sep 17 00:00:00 2001 From: chen Date: Tue, 28 Apr 2026 16:22:45 +0000 Subject: [PATCH 1/3] refactor: fuse base+LoRA matmuls before collective to fix loss divergence Inline base and LoRA matmuls, add locally, then issue a single AllGather/AllReduce instead of two separate collective ops. The prior two-collective approach caused floating-point divergence in DDP loss. Also fix LoadLoRAWeights to slice sharded tensors by tp_rank when the checkpoint shape differs from the partitioned model shape. --- .../src/nn/lora/lora_parallel_linear.cc | 104 ++++++++---------- infini_train/src/nn/lora/lora_utils.cc | 25 ++++- 2 files changed, 68 insertions(+), 61 deletions(-) diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc index 760ed3d8..1b160c80 100644 --- a/infini_train/src/nn/lora/lora_parallel_linear.cc +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -126,39 +126,35 @@ LoRAColumnParallelLinear::Forward(const std::vector> &in << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // 1. Compute base output via parent class - auto base_result = ColumnParallelLinear::Forward(input_tensors); - auto base_output = base_result[0]; - - // 2. Compute LoRA output using the SAME input that base module uses - // Match base input path exactly: use direct input if input_is_parallel_ or sequence_parallel_, - // otherwise copy to TP region - auto lora_input = (input_is_parallel_ || sequence_parallel_) - ? input_tensors[0] - : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllGather ops which cause floating-point divergence. + auto input = (input_is_parallel_ || sequence_parallel_) ? input_tensors[0] + : parallel::CopyToTPRegionFunc(input_tensors[0])[0]; if (sequence_parallel_) { - // Base uses GatherFromSPRegionFunc to gather sequence dimension - lora_input = parallel::GatherFromSPRegionFunc(lora_input)[0]; + input = parallel::GatherFromSPRegionFunc(input)[0]; } - // Compute LoRA: lora_A: [rank, in_features], lora_B: [out_per_partition, rank] - auto lora_proj = std::make_shared()->Apply({lora_input, parameters_[kParamLoraAName]})[0]; + // Base matmul (bias folded in when applicable, matching ColumnParallelLinear::Forward) + auto base_shard = std::make_shared()->Apply( + (bias_ && !skip_bias_add_) + ? std::vector>{input, parameters_.at(kParamWeightName), + parameters_[kParamBiasName]} + : std::vector>{input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_A through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_A = parallel::CopyToTPRegionFunc(parameters_[kParamLoraAName])[0]; + auto lora_proj = std::make_shared()->Apply({input, lora_A})[0]; auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - // Match base output layout (gather if base gathers) - if (gather_output_) { - lora_output = parallel::GatherFromTPRegionFunc(lora_output)[0]; - } - - auto scaled_lora = lora_output->Mul(config_.Scaling()); + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); - // 3. Add LoRA contribution to base output - // Both should now have the same sequence dimension - auto output = base_output->Add(scaled_lora); + // Single collective op + auto output = gather_output_ ? 
parallel::GatherFromTPRegionFunc(combined)[0] : combined; - // Return in same format as base module return skip_bias_add_ - ? std::vector>{output, bias_ ? parameters_[kParamBiasName] : nullptr} + ? std::vector>{output, bias_ ? parameters_.at(kParamBiasName) : nullptr} : std::vector>{output}; } @@ -321,42 +317,32 @@ LoRARowParallelLinear::Forward(const std::vector> &input << "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training."; if (!merged_) { - // Get effective input - match what base module uses - auto effective_input = input_tensors[0]; - const int64_t in_dim = effective_input->Dims().back(); - - if (!input_is_parallel_) { - // base would scatter; lora must match - effective_input = parallel::ScatterToTPRegionFunc(effective_input)[0]; - CHECK_EQ(effective_input->Dims().back(), in_features_per_partition_); - } else { - // input_is_parallel_=true means caller promised shard input - CHECK_EQ(in_dim, in_features_per_partition_) - << "RowParallel expects sharded input when input_is_parallel_=true. " - << "Got full in_dim=" << in_dim << " (likely upstream gathered TP output)."; + // Inline base + LoRA matmuls, add locally, then single collective op. + // This avoids 2 separate AllReduce ops which cause floating-point divergence. + auto input = input_is_parallel_ ? input_tensors[0] : parallel::ScatterToTPRegionFunc(input_tensors[0])[0]; + + // Base matmul (no bias — RowParallel adds bias AFTER collective) + auto base_shard = std::make_shared()->Apply({input, parameters_.at(kParamWeightName)})[0]; + + // LoRA matmul (local) + // Wrap replicated lora_B through CopyToTPRegion so its gradient gets AllReduced in backward + auto lora_proj = std::make_shared()->Apply({input, parameters_[kParamLoraAName]})[0]; + auto lora_B = parallel::CopyToTPRegionFunc(parameters_[kParamLoraBName])[0]; + auto lora_output = std::make_shared()->Apply({lora_proj, lora_B})[0]; + + // Local add before collective + auto combined = base_shard->Add(lora_output->Mul(config_.Scaling())); + + // Single collective op + auto output = reduce_output_ ? (sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(combined)[0] + : parallel::ReduceFromTPRegionFunc(combined)[0]) + : combined; + + // Bias after collective (matching RowParallelLinear::Forward) + if (bias_ && !skip_bias_add_) { + output = output->Add(parameters_[kParamBiasName]); } - // 1) base output - use effective_input - auto base_result = RowParallelLinear::Forward({effective_input}); - auto base_output = base_result[0]; - - // 2) lora branch uses the SAME effective_input - auto lora_proj - = std::make_shared()->Apply({effective_input, parameters_[kParamLoraAName]})[0]; - auto lora_output = std::make_shared()->Apply({lora_proj, parameters_[kParamLoraBName]})[0]; - - // 3) apply same reduction as base - auto lora_out = lora_output; - if (reduce_output_) { - lora_out = sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(lora_out)[0] - : parallel::ReduceFromTPRegionFunc(lora_out)[0]; - } - - auto scaled_lora = lora_out->Mul(config_.Scaling()); - CHECK_EQ(base_output->NumElements(), scaled_lora->NumElements()); - auto output = base_output->Add(scaled_lora); - - // Return in same format as base module return skip_bias_add_ ? std::vector>{output, bias_ ? 
parameters_[kParamBiasName] : nullptr} : std::vector>{output}; diff --git a/infini_train/src/nn/lora/lora_utils.cc b/infini_train/src/nn/lora/lora_utils.cc index 7b8f3668..56f5f012 100644 --- a/infini_train/src/nn/lora/lora_utils.cc +++ b/infini_train/src/nn/lora/lora_utils.cc @@ -15,6 +15,7 @@ #include "infini_train/include/nn/lora/lora_parallel_linear.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/modules/module.h" +#include "infini_train/include/nn/parallel/global.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/tensor.h" @@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr model, const std::string &filepath) auto cpu_tensor = std::make_shared(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0)); file.read(reinterpret_cast(cpu_tensor->DataPtr()), num_elements * sizeof(float)); - // Load into model + // Load into model, slicing sharded tensors by tp_rank if shapes differ auto it = model_state_dict.find(name); if (it != model_state_dict.end()) { - it->second->CopyFrom(cpu_tensor); + auto &dst = it->second; + const auto &dst_dims = dst->Dims(); + if (dst_dims == dims) { + dst->CopyFrom(cpu_tensor); + } else { + // Determine which dim is sharded: find first dim where sizes differ + int shard_dim = -1; + for (int d = 0; d < static_cast(dims.size()); ++d) { + if (d < static_cast(dst_dims.size()) && dst_dims[d] != dims[d]) { + shard_dim = d; + break; + } + } + CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name + << " but no differing dim found"; + int tp_size = parallel::global::GetTensorParallelSize(); + int64_t shard_size = dims[shard_dim] / tp_size; + int64_t start = parallel::tp_rank * shard_size; + auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size); + dst->CopyFrom(sliced); + } } else { LOG(WARNING) << "LoRA parameter not found in model: " << name; } From c614ec65824720783f49419c7315f0936117d5da Mon Sep 17 00:00:00 2001 From: chen Date: Wed, 29 Apr 2026 09:03:19 +0000 Subject: [PATCH 2/3] fix: broadcast lora_A init from TP rank 0 to ensure consistent replicated weights --- .../src/nn/lora/lora_parallel_linear.cc | 35 ++++++++++++++----- scripts/run_models_and_profile.bash | 7 ++-- scripts/test_config.json | 34 +++++++++++------- 3 files changed, 53 insertions(+), 23 deletions(-) diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc index 1b160c80..595ad2ca 100644 --- a/infini_train/src/nn/lora/lora_parallel_linear.cc +++ b/infini_train/src/nn/lora/lora_parallel_linear.cc @@ -11,6 +11,7 @@ #include "infini_train/include/nn/init.h" #include "infini_train/include/nn/modules/linear.h" #include "infini_train/include/nn/parallel/global.h" +#include "infini_train/include/nn/parallel/process_group.h" #include "infini_train/include/nn/parallel/tensor_parallel.h" #include "infini_train/include/nn/parallel/utils.h" #include "infini_train/include/tensor.h" @@ -89,22 +90,38 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr(std::vector{config_.rank, in_features_}, DataType::kFLOAT32, device_) ->RequiresGrad(); - if (config_.use_kaiming_a) { - init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + + if (parallel::global::GetTensorParallelSize() > 1) { + const auto global_rank = device_.Rank().GlobalRank(); + auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type()) + ->Get(parallel::GetTensorParallelProcessGroupName(global_rank)); + const int 
tp_rank = tp_group->GetGroupRank(global_rank); + + // Only TP rank 0 generates random values; others zero-init. + // AllReduce(sum) then broadcasts rank-0's values to all TP ranks. + if (tp_rank == 0) { + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } + } else { + init::Zeros(parameters_[kParamLoraAName]); + } + tp_group->AllReduce(parameters_[kParamLoraAName]); } else { - init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } } - // lora_B: [out_per_partition, rank] - sharded like base weight parameters_[kParamLoraBName] = std::make_shared(std::vector{out_features_per_partition_, config_.rank}, DataType::kFLOAT32, device_) diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 06589904..15d32770 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -154,8 +154,9 @@ run_and_log() { > "$log_path" fi - # Write the current run command to the log - echo "[COMMAND] $cmd" >> "$log_path" + # Write the current run command to the log (expand $LORA_WEIGHTS_DIR) + local expanded_cmd="${cmd//\$LORA_WEIGHTS_DIR/$LORA_WEIGHTS_DIR}" + echo "[COMMAND] $expanded_cmd" >> "$log_path" # Run the command and append both stdout and stderr to the log file if ! eval "$cmd" >> "$log_path" 2>&1; then @@ -267,10 +268,12 @@ for ((id=0; id Date: Wed, 13 May 2026 12:14:47 +0000 Subject: [PATCH 3/3] refactor: add multi-process Scatter overload and use it for LoRA lora_A init Add ProcessGroup::Scatter(tensor, dim, src_group_rank) overload where each process only materializes shards for its own local devices. Use it in LoRARowParallelLinear to replace broadcast+slice, avoiding tp_size-fold communication volume during init. --- .../include/nn/parallel/process_group.h | 17 +- .../src/nn/lora/lora_parallel_linear.cc | 41 +++- infini_train/src/nn/parallel/process_group.cc | 214 +++++++++++++++--- 3 files changed, 223 insertions(+), 49 deletions(-) diff --git a/infini_train/include/nn/parallel/process_group.h b/infini_train/include/nn/parallel/process_group.h index 74bf80c6..aeb67af4 100644 --- a/infini_train/include/nn/parallel/process_group.h +++ b/infini_train/include/nn/parallel/process_group.h @@ -59,14 +59,25 @@ class ProcessGroup { bool async_op = false) const; // Legacy communication APIs (Single-stream) - virtual std::vector> - BroadCast(const std::vector> &input_tensors) const; + // If root_group_rank is -1, infer root from input_tensors[0]'s device (single-process mode). + // In multi-process mode, the caller must pass the source's group rank on every rank. + virtual std::vector> BroadCast(const std::vector> &input_tensors, + int root_group_rank = -1) const; virtual std::vector> ReduceAddCoalesced(const std::vector>> &grads, Device destination) const; + // Single-process / DataParallel form: `devices` enumerates all target devices (must be local + // to this process). Source is inferred from `tensor->GetDevice()` when `src_group_rank` is -1. 
    virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
-                                                          std::vector<Device> devices, int64_t dim) const;
+                                                          std::vector<Device> devices, int64_t dim,
+                                                          int src_group_rank = -1) const;
+
+    // Multi-process-friendly form (TP init etc.): each process only materializes shard(s) for
+    // its own local device(s) in this group. `tensor` must carry the full shape/dtype on every
+    // process; data is only read on the src process.
+    virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor, int64_t dim,
+                                                          int src_group_rank) const;
 
     virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
                                            int64_t dim) const;
 
diff --git a/infini_train/src/nn/lora/lora_parallel_linear.cc b/infini_train/src/nn/lora/lora_parallel_linear.cc
index 595ad2ca..195e734b 100644
--- a/infini_train/src/nn/lora/lora_parallel_linear.cc
+++ b/infini_train/src/nn/lora/lora_parallel_linear.cc
@@ -102,18 +102,16 @@ void LoRAColumnParallelLinear::InitLoRAWeights() {
                              ->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
         const int tp_rank = tp_group->GetGroupRank(global_rank);
 
-        // Only TP rank 0 generates random values; others zero-init.
-        // AllReduce(sum) then broadcasts rank-0's values to all TP ranks.
+        // TP rank 0 generates random values; Broadcast replicates to other ranks.
         if (tp_rank == 0) {
             if (config_.use_kaiming_a) {
                 init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
             } else {
                 init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
             }
-        } else {
-            init::Zeros(parameters_[kParamLoraAName]);
         }
-        tp_group->AllReduce(parameters_[kParamLoraAName]);
+        auto broadcasted = tp_group->BroadCast({parameters_[kParamLoraAName]}, /*root_group_rank=*/0);
+        parameters_[kParamLoraAName]->CopyFrom(broadcasted[0]);
     } else {
         if (config_.use_kaiming_a) {
             init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
@@ -303,17 +301,42 @@ void LoRARowParallelLinear::InitLoRAWeights() {
     // lora_B: [out_features, rank] - replicated
     // lora_A: [rank, in_features_per_partition]
+    // When TP size > 1, rank 0 generates the full [lora_rank, in_features] matrix and each
+    // TP rank receives its own dim=1 shard via Scatter (see below).
     parameters_[kParamLoraAName]
         = std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_per_partition_}, DataType::kFLOAT32,
                                    device_)
               ->RequiresGrad();
-    if (config_.use_kaiming_a) {
-        init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
+
+    if (parallel::global::GetTensorParallelSize() > 1) {
+        const auto global_rank = device_.Rank().GlobalRank();
+        auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
+                             ->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
+        const int tp_rank = tp_group->GetGroupRank(global_rank);
+        const int tp_size = parallel::global::GetTensorParallelSize();
+
+        // TP rank 0 generates full [lora_rank, in_features]; scatter shards along dim=1 to all ranks.
+        // Non-src processes pass a tensor carrying only shape/dtype (contents unread).
+ auto full_lora_A = std::make_shared( + std::vector{config_.rank, in_features_per_partition_ * tp_size}, DataType::kFLOAT32, device_); + if (tp_rank == 0) { + if (config_.use_kaiming_a) { + init::KaimingUniform(full_lora_A, config_.kaiming_a_param); + } else { + init::Normal(full_lora_A, 0.0f, 0.02f); + } + } + auto shards = tp_group->Scatter(full_lora_A, /*dim=*/1, /*src_group_rank=*/0); + parameters_[kParamLoraAName]->CopyFrom(shards[0]); } else { - init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + if (config_.use_kaiming_a) { + init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param); + } else { + init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f); + } } - // lora_B: [out_features, rank] + // lora_B: [out_features, rank] - replicated, zeros parameters_[kParamLoraBName] = std::make_shared(std::vector{out_features_, config_.rank}, DataType::kFLOAT32, device_) ->RequiresGrad(); diff --git a/infini_train/src/nn/parallel/process_group.cc b/infini_train/src/nn/parallel/process_group.cc index 3c4c4910..6b5842ad 100644 --- a/infini_train/src/nn/parallel/process_group.cc +++ b/infini_train/src/nn/parallel/process_group.cc @@ -248,39 +248,40 @@ std::shared_ptr ProcessGroup::Recv(std::vector> te } } -std::vector> -ProcessGroup::BroadCast(const std::vector> &input_tensors) const { +std::vector> ProcessGroup::BroadCast(const std::vector> &input_tensors, + int root_group_rank) const { std::vector> outputs; std::vector streams; std::vector comms; - std::vector devices; - CHECK_EQ(world_size_, comms_.size()); - for (size_t i = 0; i < world_size_; ++i) { - auto device = devices_[i]; + // Only iterate over this process's devices (in single-process mode this equals world_size_; + // in multi-process mode it is a strict subset). + for (const auto &device : devices_) { for (const auto &input_tensor : input_tensors) { outputs.push_back(std::make_shared(input_tensor->Dims(), input_tensor->Dtype(), device)); } - devices.push_back(device); streams.push_back(runtime_impl_->GetStream(device)); comms.push_back(device_comm_map_.at(device.index())); } - int root = -1; - for (size_t i = 0; i < devices.size(); ++i) { - if (devices[i] == input_tensors[0]->GetDevice()) { - root = static_cast(i); - break; - } + // Determine NCCL root (= group rank of the source). In single-process mode the caller may + // omit it and we infer from input_tensors[0]->GetDevice(); in multi-process mode the source + // may not be on this process, so the caller must provide the group rank explicitly. + int root = root_group_rank; + if (root < 0) { + auto it = global_group_rank_map_.find(input_tensors[0]->GetDevice().Rank().GlobalRank()); + CHECK(it != global_group_rank_map_.end()) + << "BroadCast: root device not found in group and root_group_rank was not provided"; + root = it->second; } - CHECK_NE(root, -1) << "Root not found in input devices"; - core::CclGroupGuard ccl_group_guard(devices[0].type()); - for (size_t i = 0; i < devices.size(); ++i) { - core::DeviceGuard guard(devices[i]); + core::CclGroupGuard ccl_group_guard(devices_[0].type()); + for (size_t i = 0; i < devices_.size(); ++i) { + core::DeviceGuard guard(devices_[i]); + const int local_group_rank = global_group_rank_map_.at(devices_[i].Rank().GlobalRank()); for (size_t j = 0; j < input_tensors.size(); ++j) { const auto &input_tensor = input_tensors[j]; - const void *send_buffer = (static_cast(i) == root ? input_tensor->DataPtr() : nullptr); + const void *send_buffer = (local_group_rank == root ? 
input_tensor->DataPtr() : nullptr); ccl_impl_->Broadcast(send_buffer, outputs[i * input_tensors.size() + j]->DataPtr(), input_tensor->NumElements(), input_tensor->Dtype(), root, comms[i], streams[i]); } @@ -330,30 +331,169 @@ ProcessGroup::ReduceAddCoalesced(const std::vector> ProcessGroup::Scatter(const std::shared_ptr &tensor, - std::vector devices, int64_t dim) const { + std::vector devices, int64_t dim, + int src_group_rank) const { + CHECK_EQ(devices.size(), static_cast(world_size_)) << "Scatter expects one device per group rank"; + CHECK_GT(devices.size(), 0); + CHECK(tensor != nullptr) << "Scatter: tensor carrying full shape/dtype must be provided on every process"; + + // Resolve src rank: explicit overrides inference from tensor device. + int src_rank = src_group_rank; + if (src_rank < 0) { + for (size_t i = 0; i < devices.size(); ++i) { + if (tensor->GetDevice() == devices[i]) { + src_rank = static_cast(i); + break; + } + } + CHECK_NE(src_rank, -1) << "Source device not found in input devices"; + } + CHECK_GE(src_rank, 0); + CHECK_LT(src_rank, world_size_); + + // Identify local group ranks (in the same order as devices_). + std::vector local_group_ranks; + local_group_ranks.reserve(devices_.size()); + for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); } + const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank); + const bool src_is_local = src_local_it != local_group_ranks.end(); + + // Source splits only when it owns the full tensor. Shard shape is identical for all ranks + // when the dim is evenly divisible; we rely on that for preallocation on non-src processes. + CHECK_EQ(tensor->Dims()[dim] % static_cast(devices.size()), 0) + << "Scatter: dim size must be divisible by world size"; + const int64_t shard_size = tensor->Dims()[dim] / static_cast(devices.size()); + std::vector> split_tensors; + if (src_is_local) { + split_tensors = tensor->Split(shard_size, dim); + CHECK_EQ(split_tensors.size(), devices.size()); + } + + std::vector shard_dims = tensor->Dims(); + shard_dims[dim] = shard_size; + const DataType shard_dtype = tensor->Dtype(); + + // Preallocate output shards for this process's local devices. std::vector> outputs; - auto split_tensors = tensor->Split(tensor->Dims()[dim] / devices.size(), dim); - std::vector streams; - std::vector comms; - int src_rank = -1; + outputs.reserve(devices_.size()); + for (const auto &d : devices_) { outputs.push_back(std::make_shared(shard_dims, shard_dtype, d)); } - for (size_t i = 0; i < devices.size(); ++i) { - if (tensor->GetDevice() == devices[i]) { - src_rank = static_cast(i); + // Single-process mode: all devices live here, keep the symmetric Send/Recv loop for clarity. 
+ if (global::GetNnodes() == 1 && global::GetNprocPerNode() == 1) { + std::vector streams; + std::vector comms; + streams.reserve(devices.size()); + comms.reserve(devices.size()); + for (const auto &d : devices) { + streams.push_back(runtime_impl_->GetStream(d)); + comms.push_back(device_comm_map_.at(d.index())); } - outputs.push_back(std::make_shared(split_tensors[i]->Dims(), split_tensors[i]->Dtype(), devices[i])); - streams.push_back(runtime_impl_->GetStream(devices[i])); - comms.push_back(device_comm_map_.at(devices[i].index())); + core::CclGroupGuard ccl_group_guard(devices[0].type()); + for (size_t i = 0; i < devices.size(); ++i) { + core::DeviceGuard guard(devices[i]); + ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), shard_dtype, + static_cast(i), comms[src_rank], streams[src_rank]); + ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comms[i], + streams[i]); + } + return outputs; } - CHECK_NE(src_rank, -1) << "Source device not found in input devices"; - core::CclGroupGuard ccl_group_guard(devices[0].type()); - for (size_t i = 0; i < devices.size(); ++i) { - core::DeviceGuard guard(devices[i]); - ccl_impl_->Send(split_tensors[i]->DataPtr(), split_tensors[i]->NumElements(), tensor->Dtype(), i, - comms[src_rank], streams[src_rank]); - ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), tensor->Dtype(), src_rank, comms[i], - streams[i]); + // Multi-process mode: each process handles only its local device(s). + core::CclGroupGuard ccl_group_guard(devices_[0].type()); + + // Src issues a Send to every non-src group rank (including group ranks hosted in other processes). + if (src_is_local) { + const size_t src_local_idx = static_cast(src_local_it - local_group_ranks.begin()); + const auto &src_device = devices_[src_local_idx]; + core::DeviceGuard guard(src_device); + auto *stream = runtime_impl_->GetStream(src_device); + auto *comm = device_comm_map_.at(src_device.index()); + for (int dst = 0; dst < world_size_; ++dst) { + if (dst == src_rank) { + continue; + } + ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm, + stream); + } + } + + // Every local device posts either a local copy (if it is src) or a Recv from src. + for (size_t i = 0; i < devices_.size(); ++i) { + const auto &local_device = devices_[i]; + const int local_rank = local_group_ranks[i]; + if (src_is_local && local_rank == src_rank) { + outputs[i]->CopyFrom(split_tensors[src_rank]); + continue; + } + core::DeviceGuard guard(local_device); + auto *stream = runtime_impl_->GetStream(local_device); + auto *comm = device_comm_map_.at(local_device.index()); + ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream); + } + return outputs; +} + +std::vector> ProcessGroup::Scatter(const std::shared_ptr &tensor, int64_t dim, + int src_group_rank) const { + CHECK(tensor != nullptr) << "Scatter: tensor carrying full shape/dtype must be provided on every process"; + CHECK_GE(src_group_rank, 0); + CHECK_LT(src_group_rank, world_size_); + CHECK_GT(devices_.size(), 0); + const int src_rank = src_group_rank; + + // Identify local group ranks (in the same order as devices_). 
+    std::vector<int> local_group_ranks;
+    local_group_ranks.reserve(devices_.size());
+    for (const auto &d : devices_) { local_group_ranks.push_back(global_group_rank_map_.at(d.Rank().GlobalRank())); }
+    const auto src_local_it = std::find(local_group_ranks.begin(), local_group_ranks.end(), src_rank);
+    const bool src_is_local = src_local_it != local_group_ranks.end();
+
+    CHECK_EQ(tensor->Dims()[dim] % static_cast<int64_t>(world_size_), 0)
+        << "Scatter: dim size must be divisible by world size";
+    const int64_t shard_size = tensor->Dims()[dim] / static_cast<int64_t>(world_size_);
+    std::vector<std::shared_ptr<Tensor>> split_tensors;
+    if (src_is_local) {
+        split_tensors = tensor->Split(shard_size, dim);
+        CHECK_EQ(split_tensors.size(), static_cast<size_t>(world_size_));
+    }
+
+    std::vector<int64_t> shard_dims = tensor->Dims();
+    shard_dims[dim] = shard_size;
+    const DataType shard_dtype = tensor->Dtype();
+
+    std::vector<std::shared_ptr<Tensor>> outputs;
+    outputs.reserve(devices_.size());
+    for (const auto &d : devices_) { outputs.push_back(std::make_shared<Tensor>(shard_dims, shard_dtype, d)); }
+
+    core::CclGroupGuard ccl_group_guard(devices_[0].type());
+
+    if (src_is_local) {
+        const size_t src_local_idx = static_cast<size_t>(src_local_it - local_group_ranks.begin());
+        const auto &src_device = devices_[src_local_idx];
+        core::DeviceGuard guard(src_device);
+        auto *stream = runtime_impl_->GetStream(src_device);
+        auto *comm = device_comm_map_.at(src_device.index());
+        for (int dst = 0; dst < world_size_; ++dst) {
+            if (dst == src_rank) {
+                continue;
+            }
+            ccl_impl_->Send(split_tensors[dst]->DataPtr(), split_tensors[dst]->NumElements(), shard_dtype, dst, comm,
+                            stream);
+        }
+    }
+
+    for (size_t i = 0; i < devices_.size(); ++i) {
+        const auto &local_device = devices_[i];
+        const int local_rank = local_group_ranks[i];
+        if (src_is_local && local_rank == src_rank) {
+            outputs[i]->CopyFrom(split_tensors[src_rank]);
+            continue;
+        }
+        core::DeviceGuard guard(local_device);
+        auto *stream = runtime_impl_->GetStream(local_device);
+        auto *comm = device_comm_map_.at(local_device.index());
+        ccl_impl_->Recv(outputs[i]->DataPtr(), outputs[i]->NumElements(), shard_dtype, src_rank, comm, stream);
     }
     return outputs;
 }
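
Note on the shared sharding arithmetic: the tp_rank slicing added to LoadLoRAWeights in patch 1 and the dim-wise Scatter added in patch 3 rely on the same partitioning rule: a replicated matrix of shape [rank, in_features] is cut into tp_size equal blocks along the sharded dim, and TP rank r owns the half-open slice [r * in_features / tp_size, (r + 1) * in_features / tp_size). The standalone sketch below illustrates only that arithmetic; it uses plain standard C++ and hypothetical names (SliceColumns is not part of the infini_train codebase).

// sharding_sketch.cc - illustration only, not part of the patches above.
#include <cassert>
#include <cstdio>
#include <vector>

// Copy the column block of a row-major [rows, cols] matrix that belongs to `tp_rank`.
std::vector<float> SliceColumns(const std::vector<float> &full, int rows, int cols, int tp_size, int tp_rank) {
    assert(cols % tp_size == 0 && "sharded dim must be divisible by TP size");
    const int shard_cols = cols / tp_size;
    const int start = tp_rank * shard_cols; // same offset rule as the tp_rank slicing / Scatter shards
    std::vector<float> shard(static_cast<size_t>(rows) * shard_cols);
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < shard_cols; ++c) {
            shard[r * shard_cols + c] = full[r * cols + start + c];
        }
    }
    return shard;
}

int main() {
    const int lora_rank = 2, in_features = 8, tp_size = 4;
    std::vector<float> lora_A(lora_rank * in_features);
    for (size_t i = 0; i < lora_A.size(); ++i) { lora_A[i] = static_cast<float>(i); }

    for (int tp_rank = 0; tp_rank < tp_size; ++tp_rank) {
        const int shard_cols = in_features / tp_size;
        const int start = tp_rank * shard_cols;
        auto shard = SliceColumns(lora_A, lora_rank, in_features, tp_size, tp_rank);
        std::printf("tp_rank %d owns columns [%d, %d), first value %.1f\n", tp_rank, start, start + shard_cols,
                    shard[0]);
    }
    return 0;
}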