GatedDeltaNet: Add fused kernel for Qwen3.5 path

This commit is contained in:
turboderp
2026-03-03 06:10:06 +01:00
parent e5b522872b
commit d3d76d38f8
4 changed files with 115 additions and 5 deletions

View File

@@ -107,6 +107,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("add", &add, "add");
m.def("gated_delta_net_fused_op", &gated_delta_net_fused_op, "gated_delta_net_fused_op");
m.def("gated_delta_net_fused_op_2", &gated_delta_net_fused_op_2, "gated_delta_net_fused_op_2");
m.def("cuda_recurrent_gated_delta_rule", &cuda_recurrent_gated_delta_rule, "cuda_recurrent_gated_delta_rule");
m.def("argmax_sample", &argmax_sample, "argmax_sample");

View File

@@ -14,6 +14,16 @@ void gated_delta_net_fused_op
size_t v_head_dim
);
// Fused beta/gate computation for the Qwen3.5 GatedDeltaNet path:
// beta = sigmoid(b) downcast to bfloat16, g = -softplus(a + dt_bias) * exp(a_log).
void gated_delta_net_fused_op_2
(
const at::Tensor& b,        // [B,S,H] float, beta logits
const at::Tensor& a,        // [B,S,H] float, gate logits
const at::Tensor& dt_bias,  // [H] bfloat16, per-head bias added to a
const at::Tensor& a_log,    // [H] float, per-head log decay
at::Tensor& beta,           // out: [B,S,H] bfloat16
at::Tensor& g               // out: [B,S,H] float
);
void cuda_recurrent_gated_delta_rule
(
const at::Tensor& mixed_qkv,

View File

@@ -16,6 +16,8 @@ using bfloat16 = __nv_bfloat16;
#define R_THREADS MAX_HEAD_DIM
#define SUBK 4
#define FUSED_OP_2_THREADS 512
// Sigmoid via the hardware __expf intrinsic (fast, reduced-precision exp).
__device__ __forceinline__ float _sigmoid_fast_exp(float x)
{
    float e = __expf(-x);
    return 1.0f / (1.0f + e);
}
@@ -137,7 +139,7 @@ void gated_delta_net_fused_op_kernel
float g = in_ba[base_ba + Ng + t];
float bi = untrunc_bf16(dt_bias[out_va_off % Nv]);
float al = untrunc_bf16(a_log[out_va_off % Nv]);
out_g[out_va_off] = -softplus(g + bi) * expf(al);
out_g[out_va_off] = -softplus(g + bi) * __expf(al);
}
}
}
@@ -213,6 +215,96 @@ void gated_delta_net_fused_op
}
/*
One thread per (row, head) element where row ranges over B*S flattened
batch/sequence positions. Expects blockDim.x == rows_per_block * H and
gridDim.x == ceil_div(B*S, rows_per_block); tail threads whose row falls
past B*S exit immediately.

Computes, per element:
    out_beta = bf16(sigmoid(in_b))
    out_g    = -softplus(in_a + dt_bias[h]) * exp(in_a_log[h])
*/
__global__ void gated_delta_net_fused_op_2_kernel
(
    const float* __restrict__ in_b,          // [B,S,H]
    const float* __restrict__ in_a,          // [B,S,H]
    const bfloat16* __restrict__ in_dt_bias, // [H]
    const float* __restrict__ in_a_log,      // [H]
    bfloat16* __restrict__ out_beta,         // [B,S,H]
    float* __restrict__ out_g,               // [B,S,H]
    int B,
    int S,
    int H,
    int rows_per_block
)
{
    int head = threadIdx.x % H;
    int row = blockIdx.x * rows_per_block + threadIdx.x / H;
    if (row >= B * S) return;

    // Flat element index; adjacent threads touch adjacent addresses (coalesced).
    int idx = row * H + head;

    float beta_f = _sigmoid_fast_exp(in_b[idx]);
    float bias = untrunc_bf16(in_dt_bias[head]);
    float gate = -softplus(in_a[idx] + bias) * __expf(in_a_log[head]);

    out_beta[idx] = trunc_bf16(beta_f);
    out_g[idx] = gate;
}
/*
For Qwen3.5, producing gate + beta tensors, downcast to bfloat16
Transpose and qkv/z cast handled by Torch
*/
/*
For Qwen3.5, producing gate + beta tensors, downcast to bfloat16:
    beta = sigmoid(b)                        -> bfloat16
    g    = -softplus(a + dt_bias) * exp(a_log) -> float
Transpose and qkv/z cast handled by Torch.

Launch geometry: each block processes rows_per_block = FUSED_OP_2_THREADS / H
rows of H heads each, one thread per element, on the current CUDA stream.
Requires 1 <= H <= FUSED_OP_2_THREADS so at least one row fits in a block.
*/
void gated_delta_net_fused_op_2
(
    const at::Tensor& b,        // [B,S,H] float
    const at::Tensor& a,        // [B,S,H] float
    const at::Tensor& dt_bias,  // [H] bfloat16
    const at::Tensor& a_log,    // [H] float
    at::Tensor& beta,           // out [B,S,H] bfloat16
    at::Tensor& g               // out [B,S,H] float
)
{
    const at::cuda::OptionalCUDAGuard device_guard(b.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    TORCH_CHECK_DTYPE(b, kFloat);
    TORCH_CHECK_DTYPE(a, kFloat);
    TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
    TORCH_CHECK_DTYPE(a_log, kFloat);
    TORCH_CHECK_DTYPE(beta, kBFloat16);
    TORCH_CHECK_DTYPE(g, kFloat);
    TORCH_CHECK_SHAPES_FULL(b, a);
    TORCH_CHECK_SHAPES(b, 2, dt_bias, 0, 1);
    TORCH_CHECK_SHAPES(b, 2, a_log, 0, 1);
    TORCH_CHECK_SHAPES_FULL(b, beta);
    TORCH_CHECK_SHAPES_FULL(b, g);

    size_t B = b.size(0);
    size_t S = b.size(1);
    size_t H = b.size(2);

    // Guard against H > FUSED_OP_2_THREADS: rows_per_block would be 0,
    // giving a zero-thread launch and a division by zero in CEIL_DIVIDE.
    TORCH_CHECK
    (
        H >= 1 && H <= FUSED_OP_2_THREADS,
        "gated_delta_net_fused_op_2: H must be in [1, ", FUSED_OP_2_THREADS, "], got ", H
    );

    int rows_per_block = FUSED_OP_2_THREADS / H;
    int threads = rows_per_block * H;
    int blocks = CEIL_DIVIDE(B * S, rows_per_block);

    gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>
    (
        (const float*) b.data_ptr(),
        (const float*) a.data_ptr(),
        (const bfloat16*) dt_bias.data_ptr(),
        (const float*) a_log.data_ptr(),
        (bfloat16*) beta.data_ptr(),
        (float*) g.data_ptr(),
        B,
        S,
        H,
        rows_per_block
    );
    cuda_check(cudaPeekAtLastError());
}
__global__ __launch_bounds__(R_THREADS * SUBK)
void cuda_recurrent_gated_delta_rule_kernel
(

View File

@@ -578,6 +578,7 @@ class GatedDeltaNet(Module):
self.v_head_dim
)
else:
# TODO: Bound class and/or graph for this part
qkv = self.qkv_proj.forward(x, params)
z = self.z_proj.forward(x, params).view(bsz, seqlen, self.num_v_heads, self.v_head_dim)
b = self.b_proj.forward(x, params)
@@ -585,12 +586,18 @@ class GatedDeltaNet(Module):
mixed_qkv = qkv.transpose(1, 2).to(torch.bfloat16).contiguous()
dt_bias = self.dt_bias.float().view(1, 1, self.num_v_heads)
a_log = self.a_log.view(1, 1, self.num_v_heads)
beta = torch.sigmoid(b).to(torch.bfloat16)
g = (-F.softplus(a + dt_bias) * torch.exp(a_log)).to(torch.float)
beta = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
g = torch.empty((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
ext.gated_delta_net_fused_op_2(
b, a,
self.dt_bias,
self.a_log,
beta, g
)
# Convolution
# TODO: Figure out an alternative or write a new kernel that won't require transposing qkv back and forth
if conv_state is None:
if save_state:
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))