mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-07-01 04:08:10 +00:00
204 lines
5.2 KiB
Python
204 lines
5.2 KiB
Python
import torch
|
|
from sgl_kernel.debug_utils import maybe_wrap_debug_kernel
|
|
from sgl_kernel.load_utils import _load_architecture_specific_ops, _preload_cuda_library
|
|
|
|
# Initialize the ops library based on current GPU
|
|
common_ops = _load_architecture_specific_ops()
|
|
|
|
# Preload the CUDA library to avoid the issue of libcudart.so.12 not found
|
|
if torch.version.cuda is not None:
|
|
_preload_cuda_library()
|
|
|
|
|
|
from sgl_kernel.allreduce import *
|
|
from sgl_kernel.attention import (
|
|
cutlass_mla_decode,
|
|
cutlass_mla_get_workspace_size,
|
|
merge_state_v2,
|
|
)
|
|
from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
|
|
from sgl_kernel.elementwise import (
|
|
concat_mla_absorb_q,
|
|
concat_mla_k,
|
|
copy_to_gpu_no_ce,
|
|
fused_add_rmsnorm,
|
|
gelu_and_mul,
|
|
gelu_tanh_and_mul,
|
|
gemma_fused_add_rmsnorm,
|
|
gemma_rmsnorm,
|
|
rmsnorm,
|
|
rotary_embedding,
|
|
silu_and_mul,
|
|
)
|
|
from sgl_kernel.expert_specialization import (
|
|
es_fp8_blockwise_scaled_grouped_mm,
|
|
es_sm100_mxfp8_blockscaled_grouped_mm,
|
|
es_sm100_mxfp8_blockscaled_grouped_quant,
|
|
)
|
|
from sgl_kernel.gemm import (
|
|
awq_dequantize,
|
|
bmm_fp8,
|
|
dsv3_fused_a_gemm,
|
|
dsv3_router_gemm,
|
|
fp8_blockwise_scaled_mm,
|
|
fp8_scaled_mm,
|
|
gptq_gemm,
|
|
gptq_shuffle,
|
|
int8_scaled_mm,
|
|
qserve_w4a8_per_chn_gemm,
|
|
qserve_w4a8_per_group_gemm,
|
|
sgl_per_token_group_quant_8bit,
|
|
sgl_per_token_group_quant_fp8,
|
|
sgl_per_token_group_quant_int8,
|
|
sgl_per_token_quant_fp8,
|
|
shuffle_rows,
|
|
)
|
|
from sgl_kernel.grammar import apply_token_bitmask_inplace_cuda
|
|
from sgl_kernel.kvcacheio import (
|
|
transfer_kv_all_layer,
|
|
transfer_kv_all_layer_mla,
|
|
transfer_kv_per_layer,
|
|
transfer_kv_per_layer_mla,
|
|
)
|
|
from sgl_kernel.mamba import (
|
|
causal_conv1d_fn_cpu,
|
|
causal_conv1d_fwd,
|
|
causal_conv1d_update,
|
|
causal_conv1d_update_cpu,
|
|
chunk_gated_delta_rule_cpu,
|
|
)
|
|
from sgl_kernel.memory import weak_ref_tensor
|
|
from sgl_kernel.moe import (
|
|
apply_shuffle_mul_sum,
|
|
fp8_blockwise_scaled_grouped_mm,
|
|
fused_qk_norm_rope,
|
|
kimi_k2_moe_fused_gate,
|
|
moe_align_block_size,
|
|
moe_fused_gate,
|
|
moe_sum,
|
|
moe_sum_reduce,
|
|
prepare_moe_input,
|
|
topk_sigmoid,
|
|
topk_softmax,
|
|
)
|
|
from sgl_kernel.quantization import (
|
|
ggml_dequantize,
|
|
ggml_moe_a8,
|
|
ggml_moe_a8_vec,
|
|
ggml_moe_get_block_size,
|
|
ggml_mul_mat_a8,
|
|
ggml_mul_mat_vec_a8,
|
|
)
|
|
from sgl_kernel.sampling import (
|
|
top_k_renorm_prob,
|
|
top_p_renorm_prob,
|
|
)
|
|
from sgl_kernel.speculative import (
|
|
build_tree_kernel_efficient,
|
|
reconstruct_indices_from_tree_mask,
|
|
segment_packbits,
|
|
tree_speculative_sampling_target_only,
|
|
verify_tree_greedy,
|
|
)
|
|
from sgl_kernel.top_k import (
|
|
fast_topk,
|
|
fast_topk_transform_fused,
|
|
fast_topk_transform_ragged_fused,
|
|
fast_topk_v2,
|
|
)
|
|
from sgl_kernel.version import __version__
|
|
|
|
if torch.version.hip is not None:
|
|
from sgl_kernel.elementwise import gelu_quick
|
|
|
|
|
|
_DEBUG_EXPORT_NAMES = [
|
|
"apply_shuffle_mul_sum",
|
|
"apply_token_bitmask_inplace_cuda",
|
|
"awq_dequantize",
|
|
"bmm_fp8",
|
|
"build_tree_kernel_efficient",
|
|
"causal_conv1d_fwd",
|
|
"causal_conv1d_update",
|
|
"concat_mla_absorb_q",
|
|
"concat_mla_k",
|
|
"copy_to_gpu_no_ce",
|
|
"cutlass_mla_decode",
|
|
"cutlass_mla_get_workspace_size",
|
|
"dsv3_fused_a_gemm",
|
|
"dsv3_router_gemm",
|
|
"es_fp8_blockwise_scaled_grouped_mm",
|
|
"es_sm100_mxfp8_blockscaled_grouped_mm",
|
|
"es_sm100_mxfp8_blockscaled_grouped_quant",
|
|
"fast_topk",
|
|
"fast_topk_transform_fused",
|
|
"fast_topk_transform_ragged_fused",
|
|
"fast_topk_v2",
|
|
"fp8_blockwise_scaled_grouped_mm",
|
|
"fp8_blockwise_scaled_mm",
|
|
"fp8_scaled_mm",
|
|
"fused_add_rmsnorm",
|
|
"fused_qk_norm_rope",
|
|
"gelu_and_mul",
|
|
"gelu_tanh_and_mul",
|
|
"gemma_fused_add_rmsnorm",
|
|
"gemma_rmsnorm",
|
|
"gptq_gemm",
|
|
"gptq_shuffle",
|
|
"int8_scaled_mm",
|
|
"kimi_k2_moe_fused_gate",
|
|
"merge_state_v2",
|
|
"moe_align_block_size",
|
|
"moe_fused_gate",
|
|
"moe_sum",
|
|
"moe_sum_reduce",
|
|
"prepare_moe_input",
|
|
"qserve_w4a8_per_chn_gemm",
|
|
"qserve_w4a8_per_group_gemm",
|
|
"reconstruct_indices_from_tree_mask",
|
|
"rmsnorm",
|
|
"rotary_embedding",
|
|
"segment_packbits",
|
|
"sgl_per_token_group_quant_8bit",
|
|
"sgl_per_token_group_quant_fp8",
|
|
"sgl_per_token_group_quant_int8",
|
|
"sgl_per_token_quant_fp8",
|
|
"shuffle_rows",
|
|
"silu_and_mul",
|
|
"top_k_renorm_prob",
|
|
"top_p_renorm_prob",
|
|
"topk_sigmoid",
|
|
"topk_softmax",
|
|
"transfer_kv_all_layer",
|
|
"transfer_kv_all_layer_mla",
|
|
"transfer_kv_per_layer",
|
|
"transfer_kv_per_layer_mla",
|
|
"tree_speculative_sampling_target_only",
|
|
"verify_tree_greedy",
|
|
"weak_ref_tensor",
|
|
]
|
|
|
|
if torch.version.hip is not None:
|
|
_DEBUG_EXPORT_NAMES.append("gelu_quick")
|
|
|
|
for _name in _DEBUG_EXPORT_NAMES:
|
|
if _name in globals():
|
|
globals()[_name] = maybe_wrap_debug_kernel(
|
|
globals()[_name], f"sgl_kernel.{_name}"
|
|
)
|
|
|
|
del _name
|
|
del _DEBUG_EXPORT_NAMES
|
|
|
|
|
|
def create_greenctx_stream_by_value(*args, **kwargs):
|
|
from sgl_kernel.spatial import create_greenctx_stream_by_value as _impl
|
|
|
|
return _impl(*args, **kwargs)
|
|
|
|
|
|
def get_sm_available(*args, **kwargs):
|
|
from sgl_kernel.spatial import get_sm_available as _impl
|
|
|
|
return _impl(*args, **kwargs)
|