mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 11:48:01 +00:00
Co-authored-by: Ma Mingfei <mingfei.ma@intel.com> Co-authored-by: Fan Yin <1106310035@qq.com>
121 lines
2.5 KiB
Python
121 lines
2.5 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
# Mamba: thin wrappers around the sgl_kernel causal conv1d and gated delta rule ops
|
|
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Call ``torch.ops.sgl_kernel.causal_conv1d_fwd`` with the given arguments.

    Every argument is handed to the op positionally and unmodified, in the
    same order as it appears in this signature.

    Returns:
        None. Whatever the op itself returns is discarded.
        NOTE(review): presumably the op delivers its result by writing into
        one of the tensor arguments -- confirm against the kernel source.
    """
    # Resolve the op first, then invoke it; the result is intentionally unused.
    fwd_op = torch.ops.sgl_kernel.causal_conv1d_fwd
    fwd_op(
        x,
        weight,
        bias_,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        silu_activation,
        pad_slot_id,
    )
|
|
|
|
|
|
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Call ``torch.ops.sgl_kernel.causal_conv1d_update`` with the given arguments.

    Every argument is handed to the op positionally and unmodified, in the
    same order as it appears in this signature.

    Returns:
        None. Whatever the op itself returns is discarded.
        NOTE(review): presumably the op updates ``x`` and/or ``conv_state``
        in place -- confirm against the kernel source.
    """
    # Resolve the op first, then invoke it; the result is intentionally unused.
    update_op = torch.ops.sgl_kernel.causal_conv1d_update
    update_op(
        x,
        conv_state,
        weight,
        bias_,
        silu_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
|
|
|
|
|
|
def causal_conv1d_fn_cpu(
    mixed_qkv_transposed,
    conv_weights,
    bias,
    activation,
    conv_states,
    has_initial_state,
    cache_indices,
    query_start_loc,
    seq_lens_cpu,
):
    """Call ``torch.ops.sgl_kernel.causal_conv1d_fwd_cpu`` and return its result.

    The wrapper reorders its arguments for the op. The op receives, in order:
    ``mixed_qkv_transposed``, ``conv_weights``, ``bias``, ``conv_states``,
    ``query_start_loc``, ``cache_indices``, ``has_initial_state``, the result
    of ``activation == "silu"``, and then the constants ``-1`` and ``True``.

    ``seq_lens_cpu`` is accepted but never read by this wrapper.

    Returns:
        Whatever the underlying op returns, passed through untouched.
    """
    conv_fwd_cpu = torch.ops.sgl_kernel.causal_conv1d_fwd_cpu
    # The op takes a flag rather than the activation name.
    apply_silu = activation == "silu"
    return conv_fwd_cpu(
        mixed_qkv_transposed,
        conv_weights,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        apply_silu,
        # NOTE(review): presumably pad_slot_id, since it sits where
        # causal_conv1d_fwd passes pad_slot_id -- confirm against the op schema.
        -1,
        # NOTE(review): meaning of this trailing flag is not visible here --
        # confirm against the op schema.
        True,
    )
|
|
|
|
|
|
def causal_conv1d_update_cpu(
    mixed_qkv, conv_states, conv_weights, bias, activation, conv_state_indices
):
    """Call ``torch.ops.sgl_kernel.causal_conv1d_update_cpu`` and return its result.

    The op receives, in order: ``mixed_qkv``, ``conv_states``,
    ``conv_weights``, ``bias``, the result of ``activation == "silu"``,
    the constant ``None``, ``conv_state_indices``, and then the constants
    ``-1`` and ``True``.

    Returns:
        Whatever the underlying op returns, passed through untouched.
    """
    conv_update_cpu = torch.ops.sgl_kernel.causal_conv1d_update_cpu
    # The op takes a flag rather than the activation name.
    apply_silu = activation == "silu"
    return conv_update_cpu(
        mixed_qkv,
        conv_states,
        conv_weights,
        bias,
        apply_silu,
        # NOTE(review): presumably cache_seqlens, since it sits where
        # causal_conv1d_update passes cache_seqlens -- confirm against the
        # op schema.
        None,
        conv_state_indices,
        # NOTE(review): presumably pad_slot_id, by the same positional
        # analogy -- confirm against the op schema.
        -1,
        # NOTE(review): meaning of this trailing flag is not visible here --
        # confirm against the op schema.
        True,
    )
|
|
|
|
|
|
def chunk_gated_delta_rule_cpu(
    q,
    k,
    v,
    g,
    beta,
    initial_state,
    cu_seqlens,
    head_first,
    use_qk_l2norm_in_kernel,
):
    """Call ``torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu`` and repackage its output.

    The op is called with ``q``, ``k``, ``v``, ``g``, ``beta``,
    ``initial_state``, a hard-coded ``True`` (output_final_state), and then
    ``cu_seqlens``, ``head_first`` and ``use_qk_l2norm_in_kernel``. Its
    result is unpacked into exactly two values.

    Returns:
        A 3-tuple ``(core_attn_out, last_recurrent_state, h)``. The first two
        items are the op's two outputs in the order the op produced them.
        ``h`` is always ``None``.
    """
    delta_rule_op = torch.ops.sgl_kernel.chunk_gated_delta_rule_cpu
    attn_out, final_state = delta_rule_op(
        q,
        k,
        v,
        g,
        beta,
        initial_state,
        True,  # output_final_state
        cu_seqlens,
        head_first,
        use_qk_l2norm_in_kernel,
    )
    # TODO: returning ``h`` is not supported yet, so the third slot is
    # always None.
    return attn_out, final_state, None
|