mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 23:37:58 +00:00
[fix]: direct accumulation
This commit is contained in:
@@ -87,19 +87,6 @@ class KExpertsSFTBuffer:
|
||||
self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory)
|
||||
|
||||
# ========== LoRA gradient buffers (6 total) ==========
|
||||
# Gate LoRA gradients
|
||||
self.grad_gate_lora_a = torch.empty((num_experts, lora_rank, hidden_size), dtype=dtype, device="cpu")
|
||||
self.grad_gate_lora_b = torch.empty((num_experts, moe_intermediate_size, lora_rank), dtype=dtype, device="cpu")
|
||||
|
||||
# Up LoRA gradients
|
||||
self.grad_up_lora_a = torch.empty((num_experts, lora_rank, hidden_size), dtype=dtype, device="cpu")
|
||||
self.grad_up_lora_b = torch.empty((num_experts, moe_intermediate_size, lora_rank), dtype=dtype, device="cpu")
|
||||
|
||||
# Down LoRA gradients
|
||||
self.grad_down_lora_a = torch.empty((num_experts, lora_rank, moe_intermediate_size), dtype=dtype, device="cpu")
|
||||
self.grad_down_lora_b = torch.empty((num_experts, hidden_size, lora_rank), dtype=dtype, device="cpu")
|
||||
|
||||
# Routing weights gradient [qlen, num_experts_per_tok] (FP32)
|
||||
self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu")
|
||||
|
||||
@@ -154,22 +141,6 @@ class KExpertsSFTBuffer:
|
||||
"""Clear all cached buffers."""
|
||||
cls.capture_buffers.clear()
|
||||
|
||||
def get_lora_grads(self) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Get all LoRA gradients as a dictionary.
|
||||
|
||||
Returns:
|
||||
Dictionary containing 6 LoRA gradient tensors
|
||||
"""
|
||||
return {
|
||||
"grad_gate_lora_a": self.grad_gate_lora_a,
|
||||
"grad_gate_lora_b": self.grad_gate_lora_b,
|
||||
"grad_up_lora_a": self.grad_up_lora_a,
|
||||
"grad_up_lora_b": self.grad_up_lora_b,
|
||||
"grad_down_lora_a": self.grad_down_lora_a,
|
||||
"grad_down_lora_b": self.grad_down_lora_b,
|
||||
}
|
||||
|
||||
|
||||
class BaseSFTMoEWrapper(_MoEBase, ABC):
|
||||
"""
|
||||
|
||||
@@ -396,6 +396,45 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
self.down_lora_a = down_lora_a.contiguous()
|
||||
self.down_lora_b = down_lora_b.contiguous()
|
||||
|
||||
self.grad_gate_lora_a = (
|
||||
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_gate_lora_b = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.grad_up_lora_a = (
|
||||
torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_up_lora_b = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self.grad_down_lora_a = (
|
||||
torch.empty(
|
||||
(self.num_experts, self.lora_rank, self.moe_intermediate_size), dtype=torch.bfloat16, device="cpu"
|
||||
)
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
self.grad_down_lora_b = (
|
||||
torch.empty((self.num_experts, self.hidden_size, self.lora_rank), dtype=torch.bfloat16, device="cpu")
|
||||
.zero_()
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
self._lora_initialized = True
|
||||
|
||||
# If weights already loaded, update LoRA pointers in C++
|
||||
@@ -550,12 +589,6 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
|
||||
# Zero out gradient buffers
|
||||
buffer.grad_input_cpu.zero_()
|
||||
buffer.grad_gate_lora_a.zero_()
|
||||
buffer.grad_gate_lora_b.zero_()
|
||||
buffer.grad_up_lora_a.zero_()
|
||||
buffer.grad_up_lora_b.zero_()
|
||||
buffer.grad_down_lora_a.zero_()
|
||||
buffer.grad_down_lora_b.zero_()
|
||||
buffer.grad_weights.zero_()
|
||||
|
||||
# Synchronize CUDA stream if input was on GPU to ensure data has arrived
|
||||
@@ -567,12 +600,12 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
self.moe.backward_task(
|
||||
buffer.grad_output_cpu.data_ptr(),
|
||||
buffer.grad_input_cpu.data_ptr(),
|
||||
buffer.grad_gate_lora_a.data_ptr(),
|
||||
buffer.grad_gate_lora_b.data_ptr(),
|
||||
buffer.grad_up_lora_a.data_ptr(),
|
||||
buffer.grad_up_lora_b.data_ptr(),
|
||||
buffer.grad_down_lora_a.data_ptr(),
|
||||
buffer.grad_down_lora_b.data_ptr(),
|
||||
self.grad_gate_lora_a.data_ptr(),
|
||||
self.grad_gate_lora_b.data_ptr(),
|
||||
self.grad_up_lora_a.data_ptr(),
|
||||
self.grad_up_lora_b.data_ptr(),
|
||||
self.grad_down_lora_a.data_ptr(),
|
||||
self.grad_down_lora_b.data_ptr(),
|
||||
buffer.grad_weights.data_ptr(),
|
||||
)
|
||||
)
|
||||
@@ -581,29 +614,19 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
# Decrease cache depth
|
||||
self._cache_depth -= 1
|
||||
|
||||
# Accumulate LoRA gradients directly to param.grad if lora_params provided
|
||||
# This avoids clone() by accumulating immediately before buffer is reused
|
||||
# - First accumulation: param.grad is None, so we clone the buffer
|
||||
# - Subsequent accumulations: param.grad exists, so we add_ directly (no clone)
|
||||
if lora_params is not None:
|
||||
lora_grad_mapping = [
|
||||
("gate_lora_a", buffer.grad_gate_lora_a),
|
||||
("gate_lora_b", buffer.grad_gate_lora_b),
|
||||
("up_lora_a", buffer.grad_up_lora_a),
|
||||
("up_lora_b", buffer.grad_up_lora_b),
|
||||
("down_lora_a", buffer.grad_down_lora_a),
|
||||
("down_lora_b", buffer.grad_down_lora_b),
|
||||
("gate_lora_a", self.grad_gate_lora_a),
|
||||
("gate_lora_b", self.grad_gate_lora_b),
|
||||
("up_lora_a", self.grad_up_lora_a),
|
||||
("up_lora_b", self.grad_up_lora_b),
|
||||
("down_lora_a", self.grad_down_lora_a),
|
||||
("down_lora_b", self.grad_down_lora_b),
|
||||
]
|
||||
for param_name, grad_buffer in lora_grad_mapping:
|
||||
param = lora_params[param_name]
|
||||
# Convert to param's device and dtype
|
||||
grad_converted = grad_buffer.to(device=param.device, dtype=param.dtype)
|
||||
if param.grad is None:
|
||||
# First accumulation: must clone since buffer will be reused
|
||||
param.grad = grad_converted.clone()
|
||||
else:
|
||||
# Subsequent accumulations: add directly (param.grad is independent)
|
||||
param.grad.add_(grad_converted)
|
||||
param.grad = grad_buffer
|
||||
|
||||
# Return gradients: if output_device specified, transfer grad_input directly
|
||||
if output_device is not None:
|
||||
|
||||
Reference in New Issue
Block a user