17 changes: 14 additions & 3 deletions infini_train/include/nn/parallel/process_group.h
@@ -59,14 +59,25 @@ class ProcessGroup {
bool async_op = false) const;

// Legacy communication APIs (Single-stream)
virtual std::vector<std::shared_ptr<Tensor>>
BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors) const;
// If root_group_rank is -1, infer root from input_tensors[0]'s device (single-process mode).
// In multi-process mode, the caller must pass the source's group rank on every rank.
virtual std::vector<std::shared_ptr<Tensor>> BroadCast(const std::vector<std::shared_ptr<Tensor>> &input_tensors,
int root_group_rank = -1) const;

virtual std::vector<std::shared_ptr<Tensor>>
ReduceAddCoalesced(const std::vector<std::vector<std::shared_ptr<Tensor>>> &grads, Device destination) const;

// Single-process / DataParallel form: `devices` enumerates all target devices (must be local
// to this process). Source is inferred from `tensor->GetDevice()` when `src_group_rank` is -1.
virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor,
std::vector<Device> devices, int64_t dim) const;
std::vector<Device> devices, int64_t dim,
int src_group_rank = -1) const;

// Multi-process-friendly form (TP init etc.): each process only materializes shard(s) for
// its own local device(s) in this group. `tensor` must carry the full shape/dtype on every
// process; data is only read on the src process.
virtual std::vector<std::shared_ptr<Tensor>> Scatter(const std::shared_ptr<Tensor> &tensor, int64_t dim,
int src_group_rank) const;

virtual std::shared_ptr<Tensor> Gather(const std::vector<std::shared_ptr<Tensor>> &tensors, Device destination,
int64_t dim) const;
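As a reading aid for the header changes above, here is a minimal sketch of how the two updated entry points are called. It mirrors the call sites added in lora_parallel_linear.cc below; tp_group, param, and full_tensor are illustrative placeholders, not symbols defined in this PR.

    // Replicate a tensor that TP group-rank 0 initialized: in multi-process mode,
    // every rank must pass the source's group rank explicitly.
    auto replicated = tp_group->BroadCast({param}, /*root_group_rank=*/0);
    param->CopyFrom(replicated[0]);

    // Shard a full tensor along dim=1: only the src rank's data is read, the other
    // ranks just supply the full shape/dtype, and each process receives the shard(s)
    // for its own local device(s).
    auto shards = tp_group->Scatter(full_tensor, /*dim=*/1, /*src_group_rank=*/0);
    param->CopyFrom(shards[0]);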
170 changes: 98 additions & 72 deletions infini_train/src/nn/lora/lora_parallel_linear.cc
@@ -11,6 +11,7 @@
#include "infini_train/include/nn/init.h"
#include "infini_train/include/nn/modules/linear.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/process_group.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/nn/parallel/utils.h"
#include "infini_train/include/tensor.h"
@@ -89,22 +90,36 @@ LoRAColumnParallelLinear::LoRAColumnParallelLinear(std::shared_ptr<parallel::Col
}

void LoRAColumnParallelLinear::InitLoRAWeights() {
// LoRA weights stored directly in parameters_
// Following PEFT pattern conceptually:
// lora_A: [rank, in_features] - replicated
// lora_A: [rank, in_features] - replicated across TP ranks
// lora_B: [out_features_per_partition, rank] - sharded like base weight

// lora_A: [rank, in_features]
parameters_[kParamLoraAName]
= std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_}, DataType::kFLOAT32, device_)
->RequiresGrad();
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);

if (parallel::global::GetTensorParallelSize() > 1) {
const auto global_rank = device_.Rank().GlobalRank();
auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
const int tp_rank = tp_group->GetGroupRank(global_rank);

// TP rank 0 generates random values; Broadcast replicates to other ranks.
if (tp_rank == 0) {
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}
auto broadcasted = tp_group->BroadCast({parameters_[kParamLoraAName]}, /*root_group_rank=*/0);
parameters_[kParamLoraAName]->CopyFrom(broadcasted[0]);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}

// lora_B: [out_per_partition, rank] - sharded like base weight
parameters_[kParamLoraBName]
= std::make_shared<Tensor>(std::vector<int64_t>{out_features_per_partition_, config_.rank}, DataType::kFLOAT32,
device_)
@@ -126,39 +141,35 @@ LoRAColumnParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &in
<< "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training.";

if (!merged_) {
// 1. Compute base output via parent class
auto base_result = ColumnParallelLinear::Forward(input_tensors);
auto base_output = base_result[0];

// 2. Compute LoRA output using the SAME input that base module uses
// Match base input path exactly: use direct input if input_is_parallel_ or sequence_parallel_,
// otherwise copy to TP region
auto lora_input = (input_is_parallel_ || sequence_parallel_)
? input_tensors[0]
: parallel::CopyToTPRegionFunc(input_tensors[0])[0];
// Inline the base and LoRA matmuls, add locally, then run a single collective op.
// This avoids two separate AllGather ops, which can cause floating-point divergence.
auto input = (input_is_parallel_ || sequence_parallel_) ? input_tensors[0]
: parallel::CopyToTPRegionFunc(input_tensors[0])[0];
if (sequence_parallel_) {
// Base uses GatherFromSPRegionFunc to gather sequence dimension
lora_input = parallel::GatherFromSPRegionFunc(lora_input)[0];
input = parallel::GatherFromSPRegionFunc(input)[0];
}

// Compute LoRA: lora_A: [rank, in_features], lora_B: [out_per_partition, rank]
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({lora_input, parameters_[kParamLoraAName]})[0];
// Base matmul (bias folded in when applicable, matching ColumnParallelLinear::Forward)
auto base_shard = std::make_shared<autograd::Linear>()->Apply(
(bias_ && !skip_bias_add_)
? std::vector<std::shared_ptr<Tensor>>{input, parameters_.at(kParamWeightName),
parameters_[kParamBiasName]}
: std::vector<std::shared_ptr<Tensor>>{input, parameters_.at(kParamWeightName)})[0];

// LoRA matmul (local)
// Wrap replicated lora_A through CopyToTPRegion so its gradient gets AllReduced in backward
auto lora_A = parallel::CopyToTPRegionFunc(parameters_[kParamLoraAName])[0];
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({input, lora_A})[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, parameters_[kParamLoraBName]})[0];

// Match base output layout (gather if base gathers)
if (gather_output_) {
lora_output = parallel::GatherFromTPRegionFunc(lora_output)[0];
}

auto scaled_lora = lora_output->Mul(config_.Scaling());
// Local add before collective
auto combined = base_shard->Add(lora_output->Mul(config_.Scaling()));

// 3. Add LoRA contribution to base output
// Both should now have the same sequence dimension
auto output = base_output->Add(scaled_lora);
// Single collective op
auto output = gather_output_ ? parallel::GatherFromTPRegionFunc(combined)[0] : combined;

// Return in same format as base module
return skip_bias_add_
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_[kParamBiasName] : nullptr}
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_.at(kParamBiasName) : nullptr}
: std::vector<std::shared_ptr<Tensor>>{output};
}
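Restated purely as a reading aid (notation assumed here, not from the PR: $X$ is the TP-region input, $W_i$ and $B_i$ are this rank's weight and lora_B shards, $A$ the replicated lora_A, $b_i$ the locally folded bias, and $s$ = config_.Scaling()), the unmerged column-parallel path computes

$$\mathrm{out}_i = X W_i^{\top} \;(+\, b_i) \;+\; s\,(X A^{\top})\, B_i^{\top},\qquad \mathrm{output} = \operatorname{AllGather}_i(\mathrm{out}_i)\ \text{when gather\_output\_ is set},$$

so the base and LoRA contributions are summed locally and at most one collective runs per forward call.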

@@ -290,17 +301,42 @@ void LoRARowParallelLinear::InitLoRAWeights() {
// lora_B: [out_features, rank] - replicated

// lora_A: [rank, in_features_per_partition]
// TP rank 0 generates the full [lora_rank, in_features] lora_A; Scatter then distributes
// one shard along dim=1 to each TP rank.
parameters_[kParamLoraAName]
= std::make_shared<Tensor>(std::vector<int64_t>{config_.rank, in_features_per_partition_}, DataType::kFLOAT32,
device_)
->RequiresGrad();
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);

if (parallel::global::GetTensorParallelSize() > 1) {
const auto global_rank = device_.Rank().GlobalRank();
auto *tp_group = parallel::ProcessGroupFactory::Instance(device_.type())
->Get(parallel::GetTensorParallelProcessGroupName(global_rank));
const int tp_rank = tp_group->GetGroupRank(global_rank);
const int tp_size = parallel::global::GetTensorParallelSize();

// TP rank 0 generates full [lora_rank, in_features]; scatter shards along dim=1 to all ranks.
// Non-src processes pass a tensor carrying only shape/dtype (contents unread).
auto full_lora_A = std::make_shared<Tensor>(
std::vector<int64_t>{config_.rank, in_features_per_partition_ * tp_size}, DataType::kFLOAT32, device_);
if (tp_rank == 0) {
if (config_.use_kaiming_a) {
init::KaimingUniform(full_lora_A, config_.kaiming_a_param);
} else {
init::Normal(full_lora_A, 0.0f, 0.02f);
}
}
auto shards = tp_group->Scatter(full_lora_A, /*dim=*/1, /*src_group_rank=*/0);
parameters_[kParamLoraAName]->CopyFrom(shards[0]);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
if (config_.use_kaiming_a) {
init::KaimingUniform(parameters_[kParamLoraAName], config_.kaiming_a_param);
} else {
init::Normal(parameters_[kParamLoraAName], 0.0f, 0.02f);
}
}

// lora_B: [out_features, rank]
// lora_B: [out_features, rank] - replicated, zeros
parameters_[kParamLoraBName]
= std::make_shared<Tensor>(std::vector<int64_t>{out_features_, config_.rank}, DataType::kFLOAT32, device_)
->RequiresGrad();
@@ -321,42 +357,32 @@ LoRARowParallelLinear::Forward(const std::vector<std::shared_ptr<Tensor>> &input
<< "Forward() on merged LoRA with requires_grad=true. Call UnmergeWeights() before training.";

if (!merged_) {
// Get effective input - match what base module uses
auto effective_input = input_tensors[0];
const int64_t in_dim = effective_input->Dims().back();

if (!input_is_parallel_) {
// base would scatter; lora must match
effective_input = parallel::ScatterToTPRegionFunc(effective_input)[0];
CHECK_EQ(effective_input->Dims().back(), in_features_per_partition_);
} else {
// input_is_parallel_=true means caller promised shard input
CHECK_EQ(in_dim, in_features_per_partition_)
<< "RowParallel expects sharded input when input_is_parallel_=true. "
<< "Got full in_dim=" << in_dim << " (likely upstream gathered TP output).";
// Inline the base and LoRA matmuls, add locally, then run a single collective op.
// This avoids two separate AllReduce ops, which can cause floating-point divergence.
auto input = input_is_parallel_ ? input_tensors[0] : parallel::ScatterToTPRegionFunc(input_tensors[0])[0];

// Base matmul (no bias — RowParallel adds bias AFTER collective)
auto base_shard = std::make_shared<autograd::Linear>()->Apply({input, parameters_.at(kParamWeightName)})[0];

// LoRA matmul (local)
// Wrap replicated lora_B through CopyToTPRegion so its gradient gets AllReduced in backward
auto lora_proj = std::make_shared<autograd::Linear>()->Apply({input, parameters_[kParamLoraAName]})[0];
auto lora_B = parallel::CopyToTPRegionFunc(parameters_[kParamLoraBName])[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, lora_B})[0];

// Local add before collective
auto combined = base_shard->Add(lora_output->Mul(config_.Scaling()));

// Single collective op
auto output = reduce_output_ ? (sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(combined)[0]
: parallel::ReduceFromTPRegionFunc(combined)[0])
: combined;

// Bias after collective (matching RowParallelLinear::Forward)
if (bias_ && !skip_bias_add_) {
output = output->Add(parameters_[kParamBiasName]);
}

// 1) base output - use effective_input
auto base_result = RowParallelLinear::Forward({effective_input});
auto base_output = base_result[0];

// 2) lora branch uses the SAME effective_input
auto lora_proj
= std::make_shared<autograd::Linear>()->Apply({effective_input, parameters_[kParamLoraAName]})[0];
auto lora_output = std::make_shared<autograd::Linear>()->Apply({lora_proj, parameters_[kParamLoraBName]})[0];

// 3) apply same reduction as base
auto lora_out = lora_output;
if (reduce_output_) {
lora_out = sequence_parallel_ ? parallel::ReduceScatterToSPRegionFunc(lora_out)[0]
: parallel::ReduceFromTPRegionFunc(lora_out)[0];
}

auto scaled_lora = lora_out->Mul(config_.Scaling());
CHECK_EQ(base_output->NumElements(), scaled_lora->NumElements());
auto output = base_output->Add(scaled_lora);

// Return in same format as base module
return skip_bias_add_
? std::vector<std::shared_ptr<Tensor>>{output, bias_ ? parameters_[kParamBiasName] : nullptr}
: std::vector<std::shared_ptr<Tensor>>{output};
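The row-parallel path above follows the same pattern; as a reading aid (again with assumed notation: $X_i$, $W_i$, $A_i$ are this rank's input, weight, and lora_A shards, $B$ the replicated lora_B, $s$ = config_.Scaling(), and Reduce standing for AllReduce, or ReduceScatter under sequence parallelism):

$$\mathrm{output} = \operatorname{Reduce}_i\!\bigl(X_i W_i^{\top} + s\,(X_i A_i^{\top})\, B^{\top}\bigr)\;(+\, b),$$

with the bias $b$ added only after the collective, matching RowParallelLinear::Forward.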
25 changes: 23 additions & 2 deletions infini_train/src/nn/lora/lora_utils.cc
@@ -15,6 +15,7 @@
#include "infini_train/include/nn/lora/lora_parallel_linear.h"
#include "infini_train/include/nn/modules/linear.h"
#include "infini_train/include/nn/modules/module.h"
#include "infini_train/include/nn/parallel/global.h"
#include "infini_train/include/nn/parallel/tensor_parallel.h"
#include "infini_train/include/tensor.h"

@@ -392,10 +393,30 @@ void LoadLoRAWeights(std::shared_ptr<Module> model, const std::string &filepath)
auto cpu_tensor = std::make_shared<Tensor>(dims, DataType::kFLOAT32, Device(Device::DeviceType::kCPU, 0));
file.read(reinterpret_cast<char *>(cpu_tensor->DataPtr()), num_elements * sizeof(float));

// Load into model
// Load into model, slicing sharded tensors by tp_rank if shapes differ
auto it = model_state_dict.find(name);
if (it != model_state_dict.end()) {
it->second->CopyFrom(cpu_tensor);
auto &dst = it->second;
const auto &dst_dims = dst->Dims();
if (dst_dims == dims) {
dst->CopyFrom(cpu_tensor);
} else {
// Determine which dim is sharded: find first dim where sizes differ
int shard_dim = -1;
for (int d = 0; d < static_cast<int>(dims.size()); ++d) {
if (d < static_cast<int>(dst_dims.size()) && dst_dims[d] != dims[d]) {
shard_dim = d;
break;
}
}
CHECK(shard_dim >= 0) << "LoadLoRAWeights: shape mismatch for " << name
<< " but no differing dim found";
int tp_size = parallel::global::GetTensorParallelSize();
int64_t shard_size = dims[shard_dim] / tp_size;
int64_t start = parallel::tp_rank * shard_size;
auto sliced = cpu_tensor->Slice(shard_dim, start, start + shard_size);
dst->CopyFrom(sliced);
}
} else {
LOG(WARNING) << "LoRA parameter not found in model: " << name;
}
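A small worked example of the slicing branch above, using hypothetical shapes (a full lora_A checkpoint of [16, 4096] loaded into a row-parallel shard of [16, 1024] with tp_size = 4 and tp_rank = 2); only the index arithmetic is shown:

    #include <cstdint>
    #include <vector>

    int main() {
        // Hypothetical shapes; the real ones come from the checkpoint header and the
        // model's (possibly sharded) destination parameter.
        const std::vector<int64_t> ckpt_dims = {16, 4096}; // full lora_A in the file
        const std::vector<int64_t> dst_dims = {16, 1024};  // this rank's shard
        const int tp_size = 4;
        const int tp_rank = 2;

        // The loader picks the first dim whose sizes differ as the sharded dim
        // (and CHECKs that one exists).
        int shard_dim = -1;
        for (int d = 0; d < static_cast<int>(ckpt_dims.size()); ++d) {
            if (dst_dims[d] != ckpt_dims[d]) {
                shard_dim = d; // dim 1 here
                break;
            }
        }
        const int64_t shard_size = ckpt_dims[shard_dim] / tp_size; // 4096 / 4 = 1024
        const int64_t start = tp_rank * shard_size;                // 2 * 1024 = 2048
        // LoadLoRAWeights then copies cpu_tensor->Slice(shard_dim, start, start + shard_size),
        // i.e. columns [2048, 3072) of the full tensor, into this rank's parameter.
        return 0;
    }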