[fix]: direct accumulation

This commit is contained in:
mrhaoxx
2026-01-27 04:27:59 +00:00
parent 7b62d826e4
commit 6a1e7c48cb
2 changed files with 52 additions and 58 deletions

View File

@@ -396,6 +396,45 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
self.down_lora_a = down_lora_a.contiguous()
self.down_lora_b = down_lora_b.contiguous()
self.grad_gate_lora_a = (
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
.zero_()
.contiguous()
)
self.grad_gate_lora_b = (
torch.empty(
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
)
.zero_()
.contiguous()
)
self.grad_up_lora_a = (
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
.zero_()
.contiguous()
)
self.grad_up_lora_b = (
torch.empty(
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
)
.zero_()
.contiguous()
)
self.grad_down_lora_a = (
torch.empty(
(self.num_experts, self.lora_rank, self.moe_intermediate_size), dtype=torch.bfloat16, device="cpu"
)
.zero_()
.contiguous()
)
self.grad_down_lora_b = (
torch.empty((self.num_experts, self.hidden_size, self.lora_rank), dtype=torch.bfloat16, device="cpu")
.zero_()
.contiguous()
)
self._lora_initialized = True
# If weights already loaded, update LoRA pointers in C++
@@ -550,12 +589,6 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
# Zero out gradient buffers
buffer.grad_input_cpu.zero_()
buffer.grad_gate_lora_a.zero_()
buffer.grad_gate_lora_b.zero_()
buffer.grad_up_lora_a.zero_()
buffer.grad_up_lora_b.zero_()
buffer.grad_down_lora_a.zero_()
buffer.grad_down_lora_b.zero_()
buffer.grad_weights.zero_()
# Synchronize CUDA stream if input was on GPU to ensure data has arrived
@@ -567,12 +600,12 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
self.moe.backward_task(
buffer.grad_output_cpu.data_ptr(),
buffer.grad_input_cpu.data_ptr(),
buffer.grad_gate_lora_a.data_ptr(),
buffer.grad_gate_lora_b.data_ptr(),
buffer.grad_up_lora_a.data_ptr(),
buffer.grad_up_lora_b.data_ptr(),
buffer.grad_down_lora_a.data_ptr(),
buffer.grad_down_lora_b.data_ptr(),
self.grad_gate_lora_a.data_ptr(),
self.grad_gate_lora_b.data_ptr(),
self.grad_up_lora_a.data_ptr(),
self.grad_up_lora_b.data_ptr(),
self.grad_down_lora_a.data_ptr(),
self.grad_down_lora_b.data_ptr(),
buffer.grad_weights.data_ptr(),
)
)
@@ -581,29 +614,19 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
# Decrease cache depth
self._cache_depth -= 1
# Accumulate LoRA gradients directly to param.grad if lora_params provided
# This avoids clone() by accumulating immediately before buffer is reused
# - First accumulation: param.grad is None, so we clone the buffer
# - Subsequent accumulations: param.grad exists, so we add_ directly (no clone)
if lora_params is not None:
lora_grad_mapping = [
("gate_lora_a", buffer.grad_gate_lora_a),
("gate_lora_b", buffer.grad_gate_lora_b),
("up_lora_a", buffer.grad_up_lora_a),
("up_lora_b", buffer.grad_up_lora_b),
("down_lora_a", buffer.grad_down_lora_a),
("down_lora_b", buffer.grad_down_lora_b),
("gate_lora_a", self.grad_gate_lora_a),
("gate_lora_b", self.grad_gate_lora_b),
("up_lora_a", self.grad_up_lora_a),
("up_lora_b", self.grad_up_lora_b),
("down_lora_a", self.grad_down_lora_a),
("down_lora_b", self.grad_down_lora_b),
]
for param_name, grad_buffer in lora_grad_mapping:
param = lora_params[param_name]
# Convert to param's device and dtype
grad_converted = grad_buffer.to(device=param.device, dtype=param.dtype)
if param.grad is None:
# First accumulation: must clone since buffer will be reused
param.grad = grad_converted.clone()
else:
# Subsequent accumulations: add directly (param.grad is independent)
param.grad.add_(grad_converted)
param.grad = grad_buffer
# Return gradients: if output_device specified, transfer grad_input directly
if output_device is not None: