from typing import Optional

import torch

from sgl_kernel.utils import _get_cache_buf

# Group sizes for which ``sgl_per_token_group_quant_8bit`` selects the v2
# kernel when the caller leaves ``enable_v2`` as None. A tuple (not a set) so
# membership keeps the same ``==`` semantics as the original inline list.
_V2_KERNEL_SUPPORTED_GROUP_SIZES = (16, 32, 64, 128)

# Size in bytes (32 MiB) of the cached workspace buffer passed to bmm_fp8.
_BMM_FP8_WORKSPACE_BYTES = 32 * 1024 * 1024


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize ``qweight`` with ``scales``/``qzeros`` via the sgl_kernel op.

    Returns the tensor produced by ``sgl_kernel.awq_dequantize``. The output
    dtype is decided by the kernel (presumably a floating type matching
    ``scales`` -- TODO confirm against the kernel).
    """
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


def int8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call ``sgl_kernel.int8_scaled_mm`` and return its result.

    Presumably computes a scaled int8 matmul of ``mat_a`` and ``mat_b``
    (scaled by ``scales_a``/``scales_b``, optional ``bias``) in ``out_dtype``;
    exact semantics are defined by the kernel -- TODO confirm.
    """
    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Call ``sgl_kernel.fp8_blockwise_scaled_mm`` and return its result.

    Unlike ``fp8_scaled_mm`` this op takes no bias argument.
    """
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Call ``sgl_kernel.fp8_scaled_mm`` (optional ``bias``) and return its result."""
    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke ``sgl_kernel.bmm_fp8`` writing into ``D``.

    The current cuBLAS handle of the active CUDA stream/device is fetched here
    and passed to the op together with the caller-supplied workspace.
    """
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul of ``A`` and ``B`` into ``out`` (allocated if None).

    When ``out`` is None it is allocated with shape
    ``(A.shape[0], A.shape[1], B.shape[2])`` on ``A.device`` with ``dtype``,
    i.e. A is expected as (batch, m, k) and B as (batch, k, n).

    Returns:
        ``out`` after the kernel has written into it.
    """
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    # Workspace is obtained through _get_cache_buf under a fixed key, so it is
    # presumably reused across calls rather than reallocated -- TODO confirm.
    workspace_buffer = _get_cache_buf(
        "bmm_fp8_workspace", _BMM_FP8_WORKSPACE_BYTES, A.device
    )
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``sgl_kernel.dsv3_fused_a_gemm`` into ``output`` (allocated if None).

    A freshly allocated ``output`` has shape ``(mat_a.shape[0], mat_b.shape[1])``
    and ``mat_a``'s dtype and device.
    """
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output


def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    """Per-token-group 8-bit quantization of ``input`` into ``output_q``/``output_s``.

    Kernel selection: if ``enable_v2`` is None, the v2 kernel is used exactly
    when ``group_size`` is one of ``_V2_KERNEL_SUPPORTED_GROUP_SIZES``. Only
    the v2 kernel accepts ``fuse_silu_and_mul=True`` or a ``masked_m`` tensor.

    Raises:
        AssertionError: if the v1 kernel is selected together with
            ``fuse_silu_and_mul`` or ``masked_m``.
    """
    if enable_v2 is None:
        enable_v2 = group_size in _V2_KERNEL_SUPPORTED_GROUP_SIZES

    if enable_v2:
        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
            input,
            output_q,
            output_s,
            group_size,
            eps,
            fp8_min,
            fp8_max,
            scale_ue8m0,
            fuse_silu_and_mul,
            masked_m,
        )

    assert not fuse_silu_and_mul, "only v2 support fuse_silu_and_mul"
    assert masked_m is None, "only v2 support masked_m"
    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )


# For legacy usage
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_8bit
sgl_per_token_group_quant_int8 = sgl_per_token_group_quant_8bit


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Call ``sgl_kernel.sgl_per_token_quant_fp8`` with caller-provided outputs.

    Nothing is returned; results are presumably written into ``output_q`` and
    ``output_s`` in place -- TODO confirm against the kernel.
    """
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-channel GEMM into ``out_feats`` (allocated if None).

    A freshly allocated ``out_feats`` has shape
    ``(in_feats.shape[0], kernel.shape[0])`` and dtype float16.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
    )
    return out_feats


def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-group GEMM into ``out_feats`` (allocated if None).

    A freshly allocated ``out_feats`` has shape
    ``(in_feats.shape[0], kernel.shape[0])`` and dtype float16.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats


def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Run ``sgl_kernel.dsv3_router_gemm`` and return a newly allocated output.

    The output has shape ``(hidden_states.shape[0], router_weights.shape[0])``,
    dtype ``out_dtype`` (bfloat16 by default) and ``hidden_states``'s device.
    """
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=out_dtype,
    )
    # Called through the op packet (no ``.default``); overload resolution is
    # left to torch.ops.
    torch.ops.sgl_kernel.dsv3_router_gemm(
        output,
        hidden_states,
        router_weights,
    )
    return output


def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Allocate an output of ``output_tensor_shape`` and fill it via ``sgl_kernel.shuffle_rows``.

    The output takes ``input_tensor``'s dtype and device. Judging by the name,
    ``dst2src_map`` presumably maps each output row to its source row in
    ``input_tensor`` -- TODO confirm against the kernel.
    """
    output_tensor = torch.empty(
        output_tensor_shape,
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
    return output_tensor


# GPTQ kernels
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """Call ``sgl_kernel.gptq_gemm`` with the GPTQ weight tensors and return its result."""
    return torch.ops.sgl_kernel.gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Call ``sgl_kernel.gptq_shuffle`` on ``q_weight`` (nothing is returned).

    Presumably reorders ``q_weight`` in place according to ``q_perm`` -- TODO
    confirm against the kernel.
    """
    # Fixed typo: was ``torch.torch.ops``; ``torch.ops`` is the registry.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)