Files
sglang/sgl-kernel/python/sgl_kernel/fused_moe.py
Xiaoyu Zhang fb04d43428 [kimi k2 thinking] Avoid useless torch.zeros_ (#13596)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-11-21 13:15:27 +08:00

58 lines
1.4 KiB
Python

from typing import Optional
import torch
def moe_wna16_marlin_gemm(
    a: torch.Tensor,
    c_or_none: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    b_zeros_or_none: Optional[torch.Tensor],
    g_idx_or_none: Optional[torch.Tensor],
    perm_or_none: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type_id: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
):
    """Call the registered ``sgl_kernel::moe_wna16_marlin_gemm`` operator.

    This is a thin Python entry point: it performs no validation or
    computation of its own and simply forwards every argument, unchanged, to
    ``torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default``. The operator must
    already be registered with PyTorch (normally by loading the compiled
    ``sgl_kernel`` extension); otherwise the attribute lookup on
    ``torch.ops.sgl_kernel`` fails.

    Args:
        a: Tensor forwarded as the first operand (presumably the input
            activations -- confirm against the kernel).
        c_or_none: Optional tensor forwarded as the second operand
            (presumably a caller-provided output buffer -- confirm). May be
            ``None``.
        b_q_weight: Tensor forwarded as the quantized-weight operand.
        b_scales: Tensor forwarded as the scales operand.
        b_zeros_or_none: Optional zero-point tensor; may be ``None``.
        g_idx_or_none: Optional group-index tensor; may be ``None``.
        perm_or_none: Optional permutation tensor; may be ``None``.
        workspace: Tensor forwarded as the kernel workspace.
        sorted_token_ids: Tensor forwarded unchanged to the operator.
        expert_ids: Tensor forwarded unchanged to the operator.
        num_tokens_post_padded: Tensor forwarded unchanged to the operator.
        topk_weights: Tensor forwarded unchanged to the operator.
        moe_block_size: Integer forwarded by keyword.
        top_k: Integer forwarded by keyword.
        mul_topk_weights: Boolean flag forwarded by keyword.
        is_ep: Boolean flag forwarded by keyword.
        b_q_type_id: Integer forwarded by keyword (presumably an identifier
            of the weight quantization type -- confirm).
        size_m: Integer forwarded by keyword.
        size_n: Integer forwarded by keyword.
        size_k: Integer forwarded by keyword.
        is_k_full: Boolean flag forwarded by keyword.
        use_atomic_add: Boolean flag forwarded by keyword.
        use_fp32_reduce: Boolean flag forwarded by keyword.
        is_zp_float: Boolean flag forwarded by keyword.

    Returns:
        Whatever the underlying operator returns (presumably the output
        tensor -- confirm against the operator schema).
    """
    # The twelve tensor operands are passed positionally and the twelve
    # scalar options by keyword, matching the original call convention.
    return torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default(
        a,
        c_or_none,
        b_q_weight,
        b_scales,
        b_zeros_or_none,
        g_idx_or_none,
        perm_or_none,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
        moe_block_size=moe_block_size,
        top_k=top_k,
        mul_topk_weights=mul_topk_weights,
        is_ep=is_ep,
        b_q_type_id=b_q_type_id,
        size_m=size_m,
        size_n=size_n,
        size_k=size_k,
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=is_zp_float,
    )