mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 11:48:01 +00:00
240 lines
6.3 KiB
Python
240 lines
6.3 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
from sgl_kernel.utils import _get_cache_buf
|
|
|
|
|
|
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Run the ``sgl_kernel.awq_dequantize`` custom op and return its result.

    Args:
        qweight: Quantized weight tensor handed to the kernel unchanged.
        scales: Scale tensor handed to the kernel unchanged.
        qzeros: Zero-point tensor handed to the kernel unchanged.

    Returns:
        Whatever tensor the kernel produces.
        NOTE(review): presumably the dequantized weights in a floating-point
        dtype -- confirm against the kernel implementation.  The return
        annotation is the general ``torch.Tensor`` rather than the legacy
        uint8-specific ``torch.ByteTensor`` so it holds for any output dtype.
    """
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
|
|
|
|
|
|
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Forward the arguments to the ``sgl_kernel.int8_scaled_mm`` custom op.

    All six values are passed positionally, in signature order, and the
    op's return value is handed back untouched.

    NOTE(review): presumably a scaled int8 matrix multiply with an optional
    bias -- shapes, dtypes and layouts are defined by the kernel; confirm
    there.
    """
    kernel = torch.ops.sgl_kernel.int8_scaled_mm.default
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
|
|
|
|
|
|
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """Forward the arguments to ``sgl_kernel.fp8_blockwise_scaled_mm``.

    The five values are passed positionally, in signature order, and the
    op's return value is handed back untouched.  Unlike the other
    ``*_scaled_mm`` wrappers in this module, no bias argument is accepted.

    NOTE(review): presumably an FP8 matmul with block-wise scales -- the
    exact scale layout is defined by the kernel; confirm there.
    """
    kernel = torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype)
|
|
|
|
|
|
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Forward the arguments to the ``sgl_kernel.fp8_scaled_mm`` custom op.

    All six values are passed positionally, in signature order, and the
    op's return value is handed back untouched.

    NOTE(review): presumably a scaled FP8 matrix multiply with an optional
    bias -- shapes, dtypes and layouts are defined by the kernel; confirm
    there.
    """
    kernel = torch.ops.sgl_kernel.fp8_scaled_mm.default
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
|
|
|
|
|
|
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke the ``sgl_kernel.bmm_fp8`` op with the current cuBLAS handle.

    Note that the op's argument order differs from this function's: the
    workspace buffer comes first here but is passed second-to-last to the
    kernel, followed by the handle.

    Args:
        workspace_buffer: Scratch tensor forwarded to the kernel.
        A: First operand.
        B: Second operand.
        D: Tensor forwarded in the output position.
            NOTE(review): presumably written in place by the kernel -- confirm.
        A_scale: Scale tensor associated with ``A``.
        B_scale: Scale tensor associated with ``B``.
    """
    # Fetch the handle before touching the op, matching the original ordering.
    handle = torch.cuda.current_blas_handle()
    kernel = torch.ops.sgl_kernel.bmm_fp8.default
    kernel(A, B, D, A_scale, B_scale, workspace_buffer, handle)
|
|
|
|
|
|
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul wrapper that allocates the output when needed.

    Args:
        A: First operand; its first two dimensions size the default output.
        B: Second operand; its third dimension sizes the default output.
        A_scale: Scale tensor associated with ``A``.
        B_scale: Scale tensor associated with ``B``.
        dtype: Dtype of the output tensor when this function allocates it
            (ignored if ``out`` is supplied).
        out: Optional caller-provided output tensor.

    Returns:
        ``out`` if it was given, otherwise a new tensor of shape
        ``(A.shape[0], A.shape[1], B.shape[2])`` on ``A.device``.
    """
    result = out
    if result is None:
        batch, rows = A.shape[0], A.shape[1]
        cols = B.shape[2]
        result = torch.empty((batch, rows, cols), device=A.device, dtype=dtype)

    # Cached scratch buffer keyed by name; size is 32 * 1024 * 1024
    # (presumably bytes -- confirm against _get_cache_buf).
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, result, A_scale, B_scale)
    return result
|
|
|
|
|
|
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call the ``sgl_kernel.dsv3_fused_a_gemm`` op, allocating the output if absent.

    Args:
        mat_a: First operand; supplies the default output's row count,
            device and dtype.
        mat_b: Second operand; supplies the default output's column count.
        output: Optional caller-provided destination tensor.

    Returns:
        ``output`` if it was given, otherwise a new tensor of shape
        ``(mat_a.shape[0], mat_b.shape[1])``.
    """
    dest = output
    if dest is None:
        shape = (mat_a.shape[0], mat_b.shape[1])
        dest = torch.empty(shape, device=mat_a.device, dtype=mat_a.dtype)

    # The kernel takes the destination as its first argument.
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(dest, mat_a, mat_b)
    return dest
|
|
|
|
|
|
def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    """Dispatch per-token-group 8-bit quantization to the v1 or v2 kernel.

    Args:
        input: Tensor forwarded to the kernel as the data to quantize.
        output_q: Tensor forwarded in the quantized-output position.
        output_s: Tensor forwarded in the scale-output position.
            NOTE(review): both outputs are presumably written in place by
            the kernel -- confirm.
        group_size: Quantization group size; also drives the default kernel
            choice (see ``enable_v2``).
        eps: Forwarded to the kernel unchanged.
        fp8_min: Forwarded to the kernel unchanged.
        fp8_max: Forwarded to the kernel unchanged.
        scale_ue8m0: Forwarded to the kernel unchanged.
        fuse_silu_and_mul: Only accepted by the v2 kernel.
        masked_m: Only accepted by the v2 kernel.
        enable_v2: Force (``True``) or forbid (``False``) the v2 kernel.
            When ``None``, v2 is used iff ``group_size`` is one of
            16, 32, 64 or 128.

    Returns:
        The v2 op's return value on the v2 path, ``None`` on the v1 path.

    Raises:
        AssertionError: If the v1 kernel is selected while
            ``fuse_silu_and_mul`` is truthy or ``masked_m`` is not ``None``.
    """
    _V2_KERNEL_SUPPORTED_GROUP_SIZES = (16, 32, 64, 128)
    if enable_v2 is None:
        enable_v2 = group_size in _V2_KERNEL_SUPPORTED_GROUP_SIZES

    if enable_v2:
        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
            input,
            output_q,
            output_s,
            group_size,
            eps,
            fp8_min,
            fp8_max,
            scale_ue8m0,
            fuse_silu_and_mul,
            masked_m,
        )

    # Raise explicitly instead of using `assert`: assert statements are
    # stripped under `python -O`, which would let the v1 kernel run while
    # silently ignoring these v2-only options.
    if fuse_silu_and_mul:
        raise AssertionError("only v2 support fuse_silu_and_mul")
    if masked_m is not None:
        raise AssertionError("only v2 support masked_m")

    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )
|
|
|
|
|
|
# Backward-compatible aliases: both legacy names refer to the unified
# 8-bit per-token-group quantization entry point.
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_int8 = (
    sgl_per_token_group_quant_8bit
)
|
|
|
|
|
|
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Invoke the ``sgl_kernel.sgl_per_token_quant_fp8`` custom op.

    The three tensors are forwarded positionally and nothing is returned.

    NOTE(review): ``output_q`` and ``output_s`` are presumably filled in
    place with the quantized values and per-token scales -- confirm against
    the kernel.
    """
    kernel = torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default
    kernel(input, output_q, output_s)
|
|
|
|
|
|
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call the ``sgl_kernel.qserve_w4a8_per_chn_gemm`` op, allocating output if absent.

    Args:
        in_feats: Input features; supplies the default output's row count
            and device.
        kernel: Weight tensor; its first dimension is the default output's
            column count.
        wscales: Forwarded to the kernel unchanged.
        ascales: Forwarded to the kernel unchanged.
        w_szs: Forwarded to the kernel unchanged.
        a_ssums: Forwarded to the kernel unchanged.
        out_feats: Optional caller-provided destination tensor.

    Returns:
        ``out_feats`` if it was given, otherwise a new ``torch.float16``
        tensor of shape ``(in_feats.shape[0], kernel.shape[0])``.
    """
    dest = out_feats
    if dest is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        shape = (in_feats.shape[0], kernel.shape[0])
        dest = torch.empty(shape, device=in_feats.device, dtype=torch.float16)

    # The destination tensor goes last in the op's argument list.
    gemm_op = torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default
    gemm_op(in_feats, kernel, wscales, ascales, w_szs, a_ssums, dest)
    return dest
|
|
|
|
|
|
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call the ``sgl_kernel.qserve_w4a8_per_group_gemm`` op, allocating output if absent.

    Args:
        in_feats: Input features; supplies the default output's row count
            and device.
        kernel: Weight tensor; its first dimension is the default output's
            column count.
        zeros: Forwarded to the kernel unchanged.
        scales_i8: Forwarded to the kernel unchanged.
        wscales: Forwarded to the kernel unchanged.
        ascales: Forwarded to the kernel unchanged.
        out_feats: Optional caller-provided destination tensor.

    Returns:
        ``out_feats`` if it was given, otherwise a new ``torch.float16``
        tensor of shape ``(in_feats.shape[0], kernel.shape[0])``.
    """
    dest = out_feats
    if dest is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        shape = (in_feats.shape[0], kernel.shape[0])
        dest = torch.empty(shape, device=in_feats.device, dtype=torch.float16)

    # The destination tensor goes last in the op's argument list.
    gemm_op = torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default
    gemm_op(in_feats, kernel, zeros, scales_i8, wscales, ascales, dest)
    return dest
|
|
|
|
|
|
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Allocate an output tensor and run the ``sgl_kernel.dsv3_router_gemm`` op.

    Args:
        hidden_states: Input activations; supplies the output's row count
            and device.
        router_weights: Router weight tensor; its first dimension is the
            output's column count.
        out_dtype: Dtype of the allocated output (``torch.bfloat16`` by
            default).

    Returns:
        A new tensor of shape
        ``(hidden_states.shape[0], router_weights.shape[0])`` passed to the
        kernel as its first argument.
        NOTE(review): presumably filled in place with the router logits --
        confirm against the kernel.
    """
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=out_dtype,
    )
    # Pin the `.default` overload explicitly, as the other wrappers in this
    # module do, instead of resolving through the overload packet per call.
    torch.ops.sgl_kernel.dsv3_router_gemm.default(
        output,
        hidden_states,
        router_weights,
    )
    return output
|
|
|
|
|
|
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Allocate an output tensor and run the ``sgl_kernel.shuffle_rows`` op.

    Args:
        input_tensor: Source tensor; supplies the output's device and dtype.
        dst2src_map: Mapping tensor forwarded to the kernel.
            NOTE(review): by its name, presumably gives the source row index
            for each destination row -- confirm against the kernel.
        output_tensor_shape: Shape of the tensor to allocate.

    Returns:
        The newly allocated tensor after it has been passed to the kernel
        as the destination argument.
    """
    shuffled = torch.empty(
        output_tensor_shape, device=input_tensor.device, dtype=input_tensor.dtype
    )
    kernel = torch.ops.sgl_kernel.shuffle_rows.default
    kernel(input_tensor, dst2src_map, shuffled)
    return shuffled
|
|
|
|
|
|
# GPTQ kernels
|
|
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """Run the ``sgl_kernel.gptq_gemm`` custom op and return its result.

    All seven arguments are forwarded positionally in signature order.

    Args:
        a: Activation tensor.
        b_q_weight: Quantized weight tensor.
        b_gptq_qzeros: GPTQ zero-point tensor.
        b_gptq_scales: GPTQ scale tensor.
        b_g_idx: Group-index tensor.
        use_shuffle: Forwarded to the kernel unchanged.
            NOTE(review): presumably signals that ``b_q_weight`` was
            pre-shuffled with ``gptq_shuffle`` -- confirm.
        bit: Forwarded to the kernel unchanged (presumably the quantization
            bit width -- confirm).

    Returns:
        The tensor produced by the kernel.
    """
    # Pin the `.default` overload explicitly, as the other wrappers in this
    # module do, instead of resolving through the overload packet per call.
    return torch.ops.sgl_kernel.gptq_gemm.default(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )
|
|
|
|
|
|
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Run the ``sgl_kernel.gptq_shuffle`` custom op; returns nothing.

    Args:
        q_weight: Quantized weight tensor forwarded to the kernel.
            NOTE(review): presumably reordered in place -- confirm.
        q_perm: Permutation tensor forwarded to the kernel.
        bit: Forwarded to the kernel unchanged (presumably the quantization
            bit width -- confirm).
    """
    # Reach the op through `torch.ops` directly; the previous `torch.torch.ops`
    # spelling depended on an incidental self-referencing attribute of the
    # torch package.
    torch.ops.sgl_kernel.gptq_shuffle.default(q_weight, q_perm, bit)
|