mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 23:37:58 +00:00
[fix]: wip
This commit is contained in:
@@ -674,6 +674,15 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
||||
buffer.grad_input_cpu.zero_()
|
||||
buffer.grad_weights.zero_()
|
||||
|
||||
# Zero out LoRA gradient buffers (C++ backward accumulates into these)
|
||||
if self.grad_gate_lora_a is not None:
|
||||
self.grad_gate_lora_a.zero_()
|
||||
self.grad_gate_lora_b.zero_()
|
||||
self.grad_up_lora_a.zero_()
|
||||
self.grad_up_lora_b.zero_()
|
||||
self.grad_down_lora_a.zero_()
|
||||
self.grad_down_lora_b.zero_()
|
||||
|
||||
# Synchronize CUDA stream if input was on GPU to ensure data has arrived
|
||||
if input_device.type == "cuda":
|
||||
torch.cuda.synchronize(input_device)
|
||||
|
||||
Reference in New Issue
Block a user