Files
sglang/sgl-kernel/python/sgl_kernel/musa.py
2026-05-07 20:38:33 -07:00

170 lines
4.1 KiB
Python

from typing import Optional
import torch
def musa_batched_rotary_embedding_contiguous(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor,
) -> None:
return torch.ops.sgl_kernel.musa_batched_rotary_embedding_contiguous(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
rot_dim,
cos_sin_cache_offsets,
)
def musa_rotary_embedding_contiguous(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
return torch.ops.sgl_kernel.musa_rotary_embedding_contiguous(
positions,
query,
key,
head_size,
cos_sin_cache,
is_neox,
)
def musa_fused_moe_gemv(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale,
B_scale,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
mul_routed_weight: bool,
topk: int,
use_int4_w4a16: bool,
use_swigelu: bool,
) -> None:
return torch.ops.sgl_kernel.musa_fused_moe_gemv(
A,
B,
C,
A_scale,
B_scale,
topk_weights,
topk_ids,
mul_routed_weight,
topk,
use_int4_w4a16,
use_swigelu,
)
def musa_fused_gemv(
x: torch.Tensor,
qweight: torch.Tensor,
x_scales: Optional[torch.Tensor] = None,
qweight_scales: Optional[torch.Tensor] = None,
use_swigelu: bool = False,
use_rms_norm: bool = False,
gamma: Optional[torch.Tensor] = None,
eps: float = 1e-6,
):
use_int4_w4a16 = False
out_shape = x.shape[:-1] + (
qweight.shape[0] if not use_swigelu else qweight.shape[0] // 2,
)
assert not (
use_swigelu and use_rms_norm
), "gemv only fused one activation (swigelu or rms_norm)!"
if use_rms_norm:
if gamma is None:
assert False, "rms_norm gamma is None!"
# fp8 grouped matmul
if qweight.dtype == torch.float8_e4m3fn:
assert qweight_scales is not None, "FP8 grouped matmul weight scales is None!"
output = torch.empty(out_shape, device=x.device, dtype=torch.bfloat16)
torch.ops.sgl_kernel.musa_fused_gemv(
x,
qweight,
output,
x_scales,
qweight_scales,
use_int4_w4a16,
use_swigelu,
use_rms_norm,
gamma,
eps,
)
return output
# w4a16 gemv
elif qweight_scales is not None:
assert (
x.dtype == torch.bfloat16 or x.dtype == torch.float16
), "W4A16 gemv only support bfloat16 or float16!"
use_int4_w4a16 = True
out_shape = x.shape[:-1] + (
qweight.shape[0] if not use_swigelu else qweight.shape[0] // 2,
)
output = torch.empty(out_shape, device=x.device, dtype=x.dtype)
torch.ops.sgl_kernel.musa_fused_gemv(
x,
qweight,
output,
None,
qweight_scales,
use_int4_w4a16,
use_swigelu,
use_rms_norm,
gamma,
eps,
)
return output
# general gemv
else:
output = torch.empty(out_shape, device=x.device, dtype=x.dtype)
torch.ops.sgl_kernel.musa_fused_gemv(
x,
qweight,
output,
None,
None,
use_int4_w4a16,
use_swigelu,
use_rms_norm,
gamma,
eps,
)
return output
def musa_fused_mul_add(
self: torch.Tensor,
bias: Optional[torch.Tensor],
scale: Optional[float],
accurate: bool = True,
):
# if accurate == False, then we call inplace op: bias += (self * scale)
if not accurate:
bias.add_(self, alpha=scale)
return bias
# otherwise, we call custom outplace op, act: output = self * scale + bias
output = torch.empty_like(self)
torch.ops.sgl_kernel.musa_fused_mul_add(output, self, bias, scale)
return output