diff --git a/kt-kernel/python/experts_sft.py b/kt-kernel/python/experts_sft.py index b7a5b6b6..654f296e 100644 --- a/kt-kernel/python/experts_sft.py +++ b/kt-kernel/python/experts_sft.py @@ -87,19 +87,6 @@ class KExpertsSFTBuffer: self.grad_output_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) self.grad_input_cpu = torch.empty((qlen, hidden_size), dtype=dtype, device="cpu", pin_memory=pin_memory) - # ========== LoRA gradient buffers (6 total) ========== - # Gate LoRA gradients - self.grad_gate_lora_a = torch.empty((num_experts, lora_rank, hidden_size), dtype=dtype, device="cpu") - self.grad_gate_lora_b = torch.empty((num_experts, moe_intermediate_size, lora_rank), dtype=dtype, device="cpu") - - # Up LoRA gradients - self.grad_up_lora_a = torch.empty((num_experts, lora_rank, hidden_size), dtype=dtype, device="cpu") - self.grad_up_lora_b = torch.empty((num_experts, moe_intermediate_size, lora_rank), dtype=dtype, device="cpu") - - # Down LoRA gradients - self.grad_down_lora_a = torch.empty((num_experts, lora_rank, moe_intermediate_size), dtype=dtype, device="cpu") - self.grad_down_lora_b = torch.empty((num_experts, hidden_size, lora_rank), dtype=dtype, device="cpu") - # Routing weights gradient [qlen, num_experts_per_tok] (FP32) self.grad_weights = torch.empty((qlen, num_experts_per_tok), dtype=torch.float32, device="cpu") @@ -154,22 +141,6 @@ class KExpertsSFTBuffer: """Clear all cached buffers.""" cls.capture_buffers.clear() - def get_lora_grads(self) -> Dict[str, torch.Tensor]: - """ - Get all LoRA gradients as a dictionary. - - Returns: - Dictionary containing 6 LoRA gradient tensors - """ - return { - "grad_gate_lora_a": self.grad_gate_lora_a, - "grad_gate_lora_b": self.grad_gate_lora_b, - "grad_up_lora_a": self.grad_up_lora_a, - "grad_up_lora_b": self.grad_up_lora_b, - "grad_down_lora_a": self.grad_down_lora_a, - "grad_down_lora_b": self.grad_down_lora_b, - } - class BaseSFTMoEWrapper(_MoEBase, ABC): """ diff --git a/kt-kernel/python/utils/amx_sft.py b/kt-kernel/python/utils/amx_sft.py index 4361c1a1..6f7829d3 100644 --- a/kt-kernel/python/utils/amx_sft.py +++ b/kt-kernel/python/utils/amx_sft.py @@ -396,6 +396,45 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper): self.down_lora_a = down_lora_a.contiguous() self.down_lora_b = down_lora_b.contiguous() + self.grad_gate_lora_a = ( + torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu") + .zero_() + .contiguous() + ) + self.grad_gate_lora_b = ( + torch.empty( + (self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu" + ) + .zero_() + .contiguous() + ) + + self.grad_up_lora_a = ( + torch.empty((self.num_experts, self.lora_rank, self.hidden_size), dtype=torch.bfloat16, device="cpu") + .zero_() + .contiguous() + ) + self.grad_up_lora_b = ( + torch.empty( + (self.num_experts, self.moe_intermediate_size, self.lora_rank), dtype=torch.bfloat16, device="cpu" + ) + .zero_() + .contiguous() + ) + + self.grad_down_lora_a = ( + torch.empty( + (self.num_experts, self.lora_rank, self.moe_intermediate_size), dtype=torch.bfloat16, device="cpu" + ) + .zero_() + .contiguous() + ) + self.grad_down_lora_b = ( + torch.empty((self.num_experts, self.hidden_size, self.lora_rank), dtype=torch.bfloat16, device="cpu") + .zero_() + .contiguous() + ) + self._lora_initialized = True # If weights already loaded, update LoRA pointers in C++ @@ -550,12 +589,6 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper): # Zero out gradient buffers buffer.grad_input_cpu.zero_() - buffer.grad_gate_lora_a.zero_() - buffer.grad_gate_lora_b.zero_() - buffer.grad_up_lora_a.zero_() - buffer.grad_up_lora_b.zero_() - buffer.grad_down_lora_a.zero_() - buffer.grad_down_lora_b.zero_() buffer.grad_weights.zero_() # Synchronize CUDA stream if input was on GPU to ensure data has arrived @@ -567,12 +600,12 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper): self.moe.backward_task( buffer.grad_output_cpu.data_ptr(), buffer.grad_input_cpu.data_ptr(), - buffer.grad_gate_lora_a.data_ptr(), - buffer.grad_gate_lora_b.data_ptr(), - buffer.grad_up_lora_a.data_ptr(), - buffer.grad_up_lora_b.data_ptr(), - buffer.grad_down_lora_a.data_ptr(), - buffer.grad_down_lora_b.data_ptr(), + self.grad_gate_lora_a.data_ptr(), + self.grad_gate_lora_b.data_ptr(), + self.grad_up_lora_a.data_ptr(), + self.grad_up_lora_b.data_ptr(), + self.grad_down_lora_a.data_ptr(), + self.grad_down_lora_b.data_ptr(), buffer.grad_weights.data_ptr(), ) ) @@ -581,29 +614,19 @@ class AMXSFTMoEWrapper(BaseSFTMoEWrapper): # Decrease cache depth self._cache_depth -= 1 - # Accumulate LoRA gradients directly to param.grad if lora_params provided - # This avoids clone() by accumulating immediately before buffer is reused - # - First accumulation: param.grad is None, so we clone the buffer - # - Subsequent accumulations: param.grad exists, so we add_ directly (no clone) if lora_params is not None: lora_grad_mapping = [ - ("gate_lora_a", buffer.grad_gate_lora_a), - ("gate_lora_b", buffer.grad_gate_lora_b), - ("up_lora_a", buffer.grad_up_lora_a), - ("up_lora_b", buffer.grad_up_lora_b), - ("down_lora_a", buffer.grad_down_lora_a), - ("down_lora_b", buffer.grad_down_lora_b), + ("gate_lora_a", self.grad_gate_lora_a), + ("gate_lora_b", self.grad_gate_lora_b), + ("up_lora_a", self.grad_up_lora_a), + ("up_lora_b", self.grad_up_lora_b), + ("down_lora_a", self.grad_down_lora_a), + ("down_lora_b", self.grad_down_lora_b), ] for param_name, grad_buffer in lora_grad_mapping: param = lora_params[param_name] - # Convert to param's device and dtype - grad_converted = grad_buffer.to(device=param.device, dtype=param.dtype) if param.grad is None: - # First accumulation: must clone since buffer will be reused - param.grad = grad_converted.clone() - else: - # Subsequent accumulations: add directly (param.grad is independent) - param.grad.add_(grad_converted) + param.grad = grad_buffer # Return gradients: if output_device specified, transfer grad_input directly if output_device is not None: