mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 23:37:58 +00:00
[fix]: direct accumulation
This commit is contained in:
@@ -396,6 +396,45 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
self.down_lora_a = down_lora_a.contiguous()
|
||||
self.down_lora_b = down_lora_b.contiguous()
|
||||
|
||||
self.grad_gate_lora_a = (
|
||||
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_gate_lora_b = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.grad_up_lora_a = (
|
||||
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_up_lora_b = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.grad_down_lora_a = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.lora_rank, self.moe_intermediate_size), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_down_lora_b = (
|
||||
torch.empty((self.num_experts, self.hidden_size, self.lora_rank), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self._lora_initialized = True
|
||||
|
||||
# If weights already loaded, update LoRA pointers in C++
|
||||
@@ -550,12 +589,6 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
|
||||
# Zero out gradient buffers
|
||||
buffer.grad_input_cpu.zero_()
|
||||
buffer.grad_gate_lora_a.zero_()
|
||||
buffer.grad_gate_lora_b.zero_()
|
||||
buffer.grad_up_lora_a.zero_()
|
||||
buffer.grad_up_lora_b.zero_()
|
||||
buffer.grad_down_lora_a.zero_()
|
||||
buffer.grad_down_lora_b.zero_()
|
||||
buffer.grad_weights.zero_()
|
||||
|
||||
# Synchronize CUDA stream if input was on GPU to ensure data has arrived
|
||||
@@ -567,12 +600,12 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
self.moe.backward_task(
|
||||
buffer.grad_output_cpu.data_ptr(),
|
||||
buffer.grad_input_cpu.data_ptr(),
|
||||
buffer.grad_gate_lora_a.data_ptr(),
|
||||
buffer.grad_gate_lora_b.data_ptr(),
|
||||
buffer.grad_up_lora_a.data_ptr(),
|
||||
buffer.grad_up_lora_b.data_ptr(),
|
||||
buffer.grad_down_lora_a.data_ptr(),
|
||||
buffer.grad_down_lora_b.data_ptr(),
|
||||
self.grad_gate_lora_a.data_ptr(),
|
||||
self.grad_gate_lora_b.data_ptr(),
|
||||
self.grad_up_lora_a.data_ptr(),
|
||||
self.grad_up_lora_b.data_ptr(),
|
||||
self.grad_down_lora_a.data_ptr(),
|
||||
self.grad_down_lora_b.data_ptr(),
|
||||
buffer.grad_weights.data_ptr(),
|
||||
)
|
||||
)
|
||||
@@ -581,29 +614,19 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
# Decrease cache depth
|
||||
self._cache_depth -= 1
|
||||
|
||||
# Accumulate LoRA gradients directly to param.grad if lora_params provided
|
||||
# This avoids clone() by accumulating immediately before buffer is reused
|
||||
# - First accumulation: param.grad is None, so we clone the buffer
|
||||
# - Subsequent accumulations: param.grad exists, so we add_ directly (no clone)
|
||||
if lora_params is not None:
|
||||
lora_grad_mapping = [
|
||||
("gate_lora_a", buffer.grad_gate_lora_a),
|
||||
("gate_lora_b", buffer.grad_gate_lora_b),
|
||||
("up_lora_a", buffer.grad_up_lora_a),
|
||||
("up_lora_b", buffer.grad_up_lora_b),
|
||||
("down_lora_a", buffer.grad_down_lora_a),
|
||||
("down_lora_b", buffer.grad_down_lora_b),
|
||||
("gate_lora_a", self.grad_gate_lora_a),
|
||||
("gate_lora_b", self.grad_gate_lora_b),
|
||||
("up_lora_a", self.grad_up_lora_a),
|
||||
("up_lora_b", self.grad_up_lora_b),
|
||||
("down_lora_a", self.grad_down_lora_a),
|
||||
("down_lora_b", self.grad_down_lora_b),
|
||||
]
|
||||
for param_name, grad_buffer in lora_grad_mapping:
|
||||
param = lora_params[param_name]
|
||||
# Convert to param's device and dtype
|
||||
grad_converted = grad_buffer.to(device=param.device, dtype=param.dtype)
|
||||
if param.grad is None:
|
||||
# First accumulation: must clone since buffer will be reused
|
||||
param.grad = grad_converted.clone()
|
||||
else:
|
||||
# Subsequent accumulations: add directly (param.grad is independent)
|
||||
param.grad.add_(grad_converted)
|
||||
param.grad = grad_buffer
|
||||
|
||||
# Return gradients: if output_device specified, transfer grad_input directly
|
||||
if output_device is not None:
|
||||
|
||||
Reference in New Issue
Block a user