[fix]: wip

This commit is contained in:
mrhaoxx
2026-02-04 05:40:47 +00:00
parent fac81ed147
commit 3b2de00593

View File

@@ -674,6 +674,15 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
buffer.grad_input_cpu.zero_()
buffer.grad_weights.zero_()
# Zero out LoRA gradient buffers (C++ backward accumulates into these)
if self.grad_gate_lora_a is not None:
self.grad_gate_lora_a.zero_()
self.grad_gate_lora_b.zero_()
self.grad_up_lora_a.zero_()
self.grad_up_lora_b.zero_()
self.grad_down_lora_a.zero_()
self.grad_down_lora_b.zero_()
# Synchronize the CUDA device (all streams) if input was on GPU to ensure data has arrived
if input_device.type == "cuda":
torch.cuda.synchronize(input_device)