From 3b2de005932c19e6b5b846df6c40b028b68fce81 Mon Sep 17 00:00:00 2001 From: mrhaoxx Date: Wed, 4 Feb 2026 05:40:47 +0000 Subject: [PATCH] [fix]: wip --- kt-kernel/python/utils/amx_sft.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/kt-kernel/python/utils/amx_sft.py b/kt-kernel/python/utils/amx_sft.py index 09aa1c71..fb4a729f 100644 --- a/kt-kernel/python/utils/amx_sft.py +++ b/kt-kernel/python/utils/amx_sft.py @@ -674,6 +674,15 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper): buffer.grad_input_cpu.zero_() buffer.grad_weights.zero_() + # Zero out LoRA gradient buffers (C++ backward accumulates into these) + if self.grad_gate_lora_a is not None: + self.grad_gate_lora_a.zero_() + self.grad_gate_lora_b.zero_() + self.grad_up_lora_a.zero_() + self.grad_up_lora_b.zero_() + self.grad_down_lora_a.zero_() + self.grad_down_lora_b.zero_() + # Synchronize CUDA stream if input was on GPU to ensure data has arrived if input_device.type == "cuda": torch.cuda.synchronize(input_device)