Mirror of https://github.com/turboderp-org/exllamav3.git (synced 2026-04-20 14:29:51 +00:00).
Commit: "GatedDeltaNet: Add fused kernel for Qwen3.5 path"
This commit is contained in:
@@ -107,6 +107,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("add", &add, "add");
|
||||
|
||||
m.def("gated_delta_net_fused_op", &gated_delta_net_fused_op, "gated_delta_net_fused_op");
|
||||
m.def("gated_delta_net_fused_op_2", &gated_delta_net_fused_op_2, "gated_delta_net_fused_op_2");
|
||||
m.def("cuda_recurrent_gated_delta_rule", &cuda_recurrent_gated_delta_rule, "cuda_recurrent_gated_delta_rule");
|
||||
|
||||
m.def("argmax_sample", &argmax_sample, "argmax_sample");
|
||||
|
||||
@@ -14,6 +14,16 @@ void gated_delta_net_fused_op
|
||||
size_t v_head_dim
|
||||
);
|
||||
|
||||
// Fused elementwise op for the Qwen3.5 GatedDeltaNet path: in one kernel,
// computes per-head beta = sigmoid(b) (downcast to bfloat16) and gate
// g = -softplus(a + dt_bias) * exp(a_log) (kept in float).
//
// Shapes/dtypes (validated by the implementation):
//   b, a:    [B, S, H] float (inputs)
//   dt_bias: [H] bfloat16 (input)
//   a_log:   [H] float (input)
//   beta:    [B, S, H] bfloat16 (output, written in place)
//   g:       [B, S, H] float (output, written in place)
void gated_delta_net_fused_op_2
(
    const at::Tensor& b,
    const at::Tensor& a,
    const at::Tensor& dt_bias,
    const at::Tensor& a_log,
    at::Tensor& beta,
    at::Tensor& g
);
|
||||
|
||||
void cuda_recurrent_gated_delta_rule
|
||||
(
|
||||
const at::Tensor& mixed_qkv,
|
||||
|
||||
@@ -16,6 +16,8 @@ using bfloat16 = __nv_bfloat16;
|
||||
#define R_THREADS MAX_HEAD_DIM
|
||||
#define SUBK 4
|
||||
|
||||
#define FUSED_OP_2_THREADS 512
|
||||
|
||||
__device__ __forceinline__ float _sigmoid_fast_exp(float x)
|
||||
{
|
||||
return 1.0f / (1.0f + __expf(-x));
|
||||
@@ -137,7 +139,7 @@ void gated_delta_net_fused_op_kernel
|
||||
float g = in_ba[base_ba + Ng + t];
|
||||
float bi = untrunc_bf16(dt_bias[out_va_off % Nv]);
|
||||
float al = untrunc_bf16(a_log[out_va_off % Nv]);
|
||||
out_g[out_va_off] = -softplus(g + bi) * expf(al);
|
||||
out_g[out_va_off] = -softplus(g + bi) * __expf(al);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -213,6 +215,96 @@ void gated_delta_net_fused_op
|
||||
}
|
||||
|
||||
|
||||
/*
Elementwise kernel behind gated_delta_net_fused_op_2.

One thread handles one (row, head) element, where a "row" is one flattened
(batch, seq) position. Each block covers rows_per_block consecutive rows,
so the launch uses rows_per_block * H threads per block and
ceil(B * S / rows_per_block) blocks. Tail blocks are guarded by the
row bound check below.

Per element:
    beta = sigmoid(b)                          -> bfloat16
    g    = -softplus(a + dt_bias) * exp(a_log) -> float
*/

__global__ void gated_delta_net_fused_op_2_kernel
(
    const float* __restrict__ in_b,          // [B,S,H]
    const float* __restrict__ in_a,          // [B,S,H]
    const bfloat16* __restrict__ in_dt_bias, // [H]
    const float* __restrict__ in_a_log,      // [H]
    bfloat16* __restrict__ out_beta,         // [B,S,H]
    float* __restrict__ out_g,               // [B,S,H]
    int B,
    int S,
    int H,
    int rows_per_block
)
{
    int head = threadIdx.x % H;
    int row = blockIdx.x * rows_per_block + threadIdx.x / H;
    if (row >= B * S) return;

    // Flat offset shared by all [B,S,H] tensors
    int idx = row * H + head;

    float b_val = in_b[idx];
    float a_val = in_a[idx];
    float bias = untrunc_bf16(in_dt_bias[head]);
    float log_a = in_a_log[head];

    out_beta[idx] = trunc_bf16(_sigmoid_fast_exp(b_val));
    out_g[idx] = -softplus(a_val + bias) * __expf(log_a);
}
|
||||
|
||||
/*
|
||||
For Qwen3.5, producing gate + beta tensors, downcast to bfloat16
|
||||
Transpose and qkv/z cast handled by Torch
|
||||
*/
|
||||
|
||||
/*
For Qwen3.5, producing gate + beta tensors, downcast to bfloat16
Transpose and qkv/z cast handled by Torch

Launches gated_delta_net_fused_op_2_kernel on the current CUDA stream:
  beta = sigmoid(b).to(bfloat16)
  g    = -softplus(a + dt_bias) * exp(a_log)   (float)

Inputs:
  b, a     [B,S,H] float
  dt_bias  [H] bfloat16
  a_log    [H] float
Outputs (written in place):
  beta     [B,S,H] bfloat16
  g        [B,S,H] float

Requires H <= FUSED_OP_2_THREADS (one row must fit in a single block).
*/

void gated_delta_net_fused_op_2
(
    const at::Tensor& b,        // [B,S,H] float
    const at::Tensor& a,        // [B,S,H] float
    const at::Tensor& dt_bias,  // [H] bfloat16
    const at::Tensor& a_log,    // [H] float
    at::Tensor& beta,           // out [B,S,H] bfloat16
    at::Tensor& g               // out [B,S,H] float
)
{
    const at::cuda::OptionalCUDAGuard device_guard(b.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(b, kFloat);
    TORCH_CHECK_DTYPE(a, kFloat);
    TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
    TORCH_CHECK_DTYPE(a_log, kFloat);
    TORCH_CHECK_DTYPE(beta, kBFloat16);
    TORCH_CHECK_DTYPE(g, kFloat);

    TORCH_CHECK_SHAPES_FULL(b, a);
    TORCH_CHECK_SHAPES(b, 2, dt_bias, 0, 1);
    TORCH_CHECK_SHAPES(b, 2, a_log, 0, 1);
    TORCH_CHECK_SHAPES_FULL(b, beta);
    TORCH_CHECK_SHAPES_FULL(b, g);

    size_t B = b.size(0);
    size_t S = b.size(1);
    size_t H = b.size(2);

    // Guard the launch geometry: if H exceeded the per-block thread budget,
    // rows_per_block would truncate to zero, giving a zero-thread launch and
    // a division by zero in the block count below.
    TORCH_CHECK
    (
        H > 0 && H <= FUSED_OP_2_THREADS,
        "gated_delta_net_fused_op_2: H must be in [1, ", FUSED_OP_2_THREADS, "], got ", H
    );

    // Nothing to do for empty inputs; a 0-block launch would be an error.
    if (B * S == 0) return;

    int rows_per_block = FUSED_OP_2_THREADS / H;
    int threads = rows_per_block * H;
    int blocks = CEIL_DIVIDE(B * S, rows_per_block);

    gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>
    (
        (const float*) b.data_ptr(),
        (const float*) a.data_ptr(),
        (const bfloat16*) dt_bias.data_ptr(),
        (const float*) a_log.data_ptr(),
        (bfloat16*) beta.data_ptr(),
        (float*) g.data_ptr(),
        B,
        S,
        H,
        rows_per_block
    );

    cuda_check(cudaPeekAtLastError());
}
|
||||
|
||||
|
||||
__global__ __launch_bounds__(R_THREADS * SUBK)
|
||||
void cuda_recurrent_gated_delta_rule_kernel
|
||||
(
|
||||
|
||||
@@ -578,6 +578,7 @@ class GatedDeltaNet(Module):
|
||||
self.v_head_dim
|
||||
)
|
||||
else:
|
||||
# TODO: Bound class and/or graph for this part
|
||||
qkv = self.qkv_proj.forward(x, params)
|
||||
z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim)
|
||||
b = self.b_proj.forward(x, params)
|
||||
@@ -585,12 +586,18 @@ class GatedDeltaNet(Module):
|
||||
|
||||
mixed_qkv = qkv.transpose(1, 2).to(torch.bfloat16).contiguous()
|
||||
|
||||
dt_bias = self.dt_bias.float().view(1, 1, self.num_v_heads)
|
||||
a_log = self.a_log.view(1, 1, self.num_v_heads)
|
||||
beta = torch.sigmoid(b).to(torch.bfloat16)
|
||||
g = (-F.softplus(a + dt_bias) * torch.exp(a_log)).to(torch.float)
|
||||
beta = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
|
||||
g = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
|
||||
|
||||
ext.gated_delta_net_fused_op_2(
|
||||
b, a,
|
||||
self.dt_bias,
|
||||
self.a_log,
|
||||
beta, g
|
||||
)
|
||||
|
||||
# Convolution
|
||||
# TODO: Figure out an alternative or write a new kernel that won't require transposing qkv back and forth
|
||||
if conv_state is None:
|
||||
if save_state:
|
||||
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||
|
||||
Reference in New Issue
Block a user