[fix](sft): align Python API with C++ backend after v5 refactor

- wrapper.py: pass gpu_experts_mask=None to KTMoEWrapper (required by C++ signature)
- layer.py: rename submit_forward_sft/sync_forward_sft to submit_forward/sync_forward
- autograd.py: rename sync_forward_sft to sync_forward

The sft-v5 refactor (commits 58d7eab, dd1da65) renamed Python-side method
calls but the C++ backend (AMXSFTMoEWrapper) still exposes the original
method names. This caused AttributeError on Qwen3.5-35B and other models.
This commit is contained in:
JimmyPeilinLi
2026-04-20 16:44:09 +00:00
parent dd1da65d90
commit 168e10f254
3 changed files with 7 additions and 6 deletions

View File

@@ -76,7 +76,7 @@ class KTMoEFunction(torch.autograd.Function):
# Rank 0: sync CPU result and split by real lengths
if rank == 0:
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
cpu_output = wrapper.sync_forward(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
@@ -96,7 +96,7 @@ class KTMoEFunction(torch.autograd.Function):
del output_flat
elif wrapper is not None:
# Single-GPU: sync directly
cpu_output = wrapper.sync_forward_sft(output_device=original_device)
cpu_output = wrapper.sync_forward(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype)
else:
# Broadcast-only rank (no wrapper)

View File

@@ -206,7 +206,7 @@ class KTMoELayerWrapper(nn.Module):
if rank == 0:
if self.wrapper is None:
raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.")
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
cpu_output = self.wrapper.sync_forward(output_device=original_device)
cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size)
offsets = _qlen_offsets(all_qlens_list)
scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)]
@@ -227,7 +227,7 @@ class KTMoELayerWrapper(nn.Module):
return output
if self.wrapper is not None:
cpu_output = self.wrapper.sync_forward_sft(output_device=original_device)
cpu_output = self.wrapper.sync_forward(output_device=original_device)
output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype)
return output
@@ -335,7 +335,7 @@ class KTMoELayerWrapper(nn.Module):
all_hs = torch.cat(gathered_hs, dim=0)
all_ids = torch.cat(gathered_ids, dim=0)
all_wts = torch.cat(gathered_wts, dim=0)
self.wrapper.submit_forward_sft(
self.wrapper.submit_forward(
all_hs,
all_ids,
all_wts,
@@ -364,7 +364,7 @@ class KTMoELayerWrapper(nn.Module):
submit_hs = input_flat.detach()
submit_ids = expert_ids.detach()
submit_wts = weights.detach()
self.wrapper.submit_forward_sft(
self.wrapper.submit_forward(
submit_hs,
submit_ids,
submit_wts,

View File

@@ -318,6 +318,7 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT
num_experts_per_tok=moe_config.num_experts_per_tok,
hidden_size=hidden_size,
moe_intermediate_size=moe_config.intermediate_size,
gpu_experts_mask=None,
num_gpu_experts=0,
cpuinfer_threads=getattr(cfg, "kt_num_threads", 1),
threadpool_count=threadpool_count,