diff --git a/kt-kernel/python/sft/autograd.py b/kt-kernel/python/sft/autograd.py index 9b2934e4..0264e9de 100644 --- a/kt-kernel/python/sft/autograd.py +++ b/kt-kernel/python/sft/autograd.py @@ -76,7 +76,7 @@ class KTMoEFunction(torch.autograd.Function): # Rank 0: sync CPU result and split by real lengths if rank == 0: - cpu_output = wrapper.sync_forward_sft(output_device=original_device) + cpu_output = wrapper.sync_forward(output_device=original_device) cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size) offsets = _qlen_offsets(all_qlens_list) scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] @@ -96,7 +96,7 @@ class KTMoEFunction(torch.autograd.Function): del output_flat elif wrapper is not None: # Single-GPU: sync directly - cpu_output = wrapper.sync_forward_sft(output_device=original_device) + cpu_output = wrapper.sync_forward(output_device=original_device) output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype) else: # Broadcast-only rank (no wrapper) diff --git a/kt-kernel/python/sft/layer.py b/kt-kernel/python/sft/layer.py index fa889721..e4cb2b65 100644 --- a/kt-kernel/python/sft/layer.py +++ b/kt-kernel/python/sft/layer.py @@ -206,7 +206,7 @@ class KTMoELayerWrapper(nn.Module): if rank == 0: if self.wrapper is None: raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.") - cpu_output = self.wrapper.sync_forward_sft(output_device=original_device) + cpu_output = self.wrapper.sync_forward(output_device=original_device) cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size) offsets = _qlen_offsets(all_qlens_list) scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] @@ -227,7 +227,7 @@ class KTMoELayerWrapper(nn.Module): return output if self.wrapper is not None: - cpu_output = self.wrapper.sync_forward_sft(output_device=original_device) + cpu_output = self.wrapper.sync_forward(output_device=original_device) output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype) return output @@ -335,7 +335,7 @@ class KTMoELayerWrapper(nn.Module): all_hs = torch.cat(gathered_hs, dim=0) all_ids = torch.cat(gathered_ids, dim=0) all_wts = torch.cat(gathered_wts, dim=0) - self.wrapper.submit_forward_sft( + self.wrapper.submit_forward( all_hs, all_ids, all_wts, @@ -364,7 +364,7 @@ class KTMoELayerWrapper(nn.Module): submit_hs = input_flat.detach() submit_ids = expert_ids.detach() submit_wts = weights.detach() - self.wrapper.submit_forward_sft( + self.wrapper.submit_forward( submit_hs, submit_ids, submit_wts, diff --git a/kt-kernel/python/sft/wrapper.py b/kt-kernel/python/sft/wrapper.py index 06706716..a53ea88a 100644 --- a/kt-kernel/python/sft/wrapper.py +++ b/kt-kernel/python/sft/wrapper.py @@ -318,6 +318,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT num_experts_per_tok=moe_config.num_experts_per_tok, hidden_size=hidden_size, moe_intermediate_size=moe_config.intermediate_size, + gpu_experts_mask=None, num_gpu_experts=0, cpuinfer_threads=getattr(cfg, "kt_num_threads", 1), threadpool_count=threadpool_count,