from typing import Optional

import torch


def moe_wna16_marlin_gemm(
    a: torch.Tensor,
    c_or_none: Optional[torch.Tensor],
    b_q_weight: torch.Tensor,
    b_bias_or_none: Optional[torch.Tensor],
    b_scales: torch.Tensor,
    global_scale_or_none: Optional[torch.Tensor],
    b_zeros_or_none: Optional[torch.Tensor],
    g_idx_or_none: Optional[torch.Tensor],
    perm_or_none: Optional[torch.Tensor],
    workspace: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    topk_weights: torch.Tensor,
    moe_block_size: int,
    top_k: int,
    mul_topk_weights: bool,
    is_ep: bool,
    b_q_type_id: int,
    size_m: int,
    size_n: int,
    size_k: int,
    is_k_full: bool,
    use_atomic_add: bool,
    use_fp32_reduce: bool,
    is_zp_float: bool,
):
    """Forward every argument to the ``sgl_kernel`` ``moe_wna16_marlin_gemm`` op.

    This is a pure pass-through: nothing is validated, converted, or
    reordered here. The fourteen tensor operands (``a`` through
    ``topk_weights``) are handed to the op positionally in declaration
    order; the remaining twelve integer/boolean options are handed over by
    keyword under their own names.

    The meaning, shape, and dtype requirements of each argument are defined
    by the underlying op and are not visible from this wrapper.
    NOTE(review): the ``*_or_none`` parameters are annotated ``Optional``,
    so presumably the op accepts ``None`` for them -- confirm against the
    op schema.

    Returns:
        Whatever ``torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default``
        returns, unmodified.
    """
    kernel = torch.ops.sgl_kernel.moe_wna16_marlin_gemm.default

    # Tensor operands: passed positionally, in signature order.
    tensor_operands = (
        a,
        c_or_none,
        b_q_weight,
        b_bias_or_none,
        b_scales,
        global_scale_or_none,
        b_zeros_or_none,
        g_idx_or_none,
        perm_or_none,
        workspace,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        topk_weights,
    )

    # Scalar options: passed by keyword, each under its own parameter name.
    scalar_options = {
        "moe_block_size": moe_block_size,
        "top_k": top_k,
        "mul_topk_weights": mul_topk_weights,
        "is_ep": is_ep,
        "b_q_type_id": b_q_type_id,
        "size_m": size_m,
        "size_n": size_n,
        "size_k": size_k,
        "is_k_full": is_k_full,
        "use_atomic_add": use_atomic_add,
        "use_fp32_reduce": use_fp32_reduce,
        "is_zp_float": is_zp_float,
    }

    return kernel(*tensor_operands, **scalar_options)