From f4c56f8c6d8e3b8f0bbd95ca5efd84e65da01efd Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Fri, 13 Mar 2026 00:44:37 +0100 Subject: [PATCH] GatedDeltaNet: Handle head sizes up to 256, divisible by down to 32, support beta scale (linear_allow_neg_eigval) --- exllamav3/exllamav3_ext/gdn.cuh | 6 +- exllamav3/exllamav3_ext/gnd.cu | 104 ++++++++++-------- .../libtorch/gated_delta_net.cpp | 3 +- .../exllamav3_ext/libtorch/gated_delta_net.h | 7 +- .../libtorch/gated_delta_net_bc.h | 6 +- exllamav3/modules/gated_delta_net.py | 11 +- 6 files changed, 81 insertions(+), 56 deletions(-) diff --git a/exllamav3/exllamav3_ext/gdn.cuh b/exllamav3/exllamav3_ext/gdn.cuh index 2da3233..592ee4c 100644 --- a/exllamav3/exllamav3_ext/gdn.cuh +++ b/exllamav3/exllamav3_ext/gdn.cuh @@ -11,7 +11,8 @@ void gated_delta_net_fused_op size_t num_k_heads, size_t num_v_heads, size_t k_head_dim, - size_t v_head_dim + size_t v_head_dim, + const float beta_scale ); void gated_delta_net_fused_op_2 @@ -21,7 +22,8 @@ void gated_delta_net_fused_op_2 const at::Tensor& dt_bias, const at::Tensor& a_log, at::Tensor& beta, - at::Tensor& g + at::Tensor& g, + const float beta_scale ); void cuda_recurrent_gated_delta_rule diff --git a/exllamav3/exllamav3_ext/gnd.cu b/exllamav3/exllamav3_ext/gnd.cu index 6834e25..9fc606d 100644 --- a/exllamav3/exllamav3_ext/gnd.cu +++ b/exllamav3/exllamav3_ext/gnd.cu @@ -9,11 +9,9 @@ #include using bfloat16 = __nv_bfloat16; -#define MAX_HEAD_DIM 128 #define MAX_K_HEADS 32 #define MAX_V_HEADS 64 -#define R_THREADS MAX_HEAD_DIM #define SUBK 4 #define FUSED_OP_2_THREADS 512 @@ -49,6 +47,7 @@ __device__ __forceinline__ float softplus(float x) // beta=1.0, linear threshol return log1pf(__expf(x)); } +template __global__ __launch_bounds__(MAX_HEAD_DIM) void gated_delta_net_fused_op_kernel ( @@ -65,7 +64,8 @@ void gated_delta_net_fused_op_kernel const size_t Nk, const size_t Ng, const size_t Hk, - const size_t Hv + const size_t Hv, + const float beta_scale ) { const size_t Nv = Nk * Ng; @@ -143,7 +143,7 @@ void gated_delta_net_fused_op_kernel // beta = sigmoid(b).bfloat16() float b = in_ba[base_ba + t]; - out_beta[out_va_off] = trunc_bf16(_sigmoid_fast_exp(b)); + out_beta[out_va_off] = trunc_bf16(_sigmoid_fast_exp(b) * beta_scale); // g = -self.a_log.float().exp() * F.softplus(a + self.dt_bias.float()) float g = in_ba[base_ba + Ng + t]; @@ -172,7 +172,8 @@ void gated_delta_net_fused_op size_t num_k_heads, size_t num_v_heads, size_t k_head_dim, - size_t v_head_dim + size_t v_head_dim, + const float beta_scale ) { const at::cuda::OptionalCUDAGuard device_guard(mixed_qkvz.device()); @@ -206,20 +207,26 @@ void gated_delta_net_fused_op const int blocks = B * S * Nk; const int threads = MAX(Hk, Hv); - TORCH_CHECK(threads <= MAX_HEAD_DIM, "Max head dim exceeded"); - gated_delta_net_fused_op_kernel<<>> - ( - (const float*) mixed_qkvz.data_ptr(), - (const float*) mixed_ba.data_ptr(), - (const bfloat16*) dt_bias.data_ptr(), - (const bfloat16*) a_log.data_ptr(), - (bfloat16*) mixed_qkv.data_ptr(), - (bfloat16*) z.data_ptr(), - (bfloat16*) beta.data_ptr(), - (float*) g.data_ptr(), - B, S, Nk, Ng, Hk, Hv - ); + #define KERNEL_ARGS \ + (const float*) mixed_qkvz.data_ptr(), \ + (const float*) mixed_ba.data_ptr(), \ + (const bfloat16*) dt_bias.data_ptr(), \ + (const bfloat16*) a_log.data_ptr(), \ + (bfloat16*) mixed_qkv.data_ptr(), \ + (bfloat16*) z.data_ptr(), \ + (bfloat16*) beta.data_ptr(), \ + (float*) g.data_ptr(), \ + B, S, Nk, Ng, Hk, Hv, \ + beta_scale + + if (threads <= 128) + gated_delta_net_fused_op_kernel<128><<>>(KERNEL_ARGS); + else if (threads <= 256) + gated_delta_net_fused_op_kernel<256><<>>(KERNEL_ARGS); + else TORCH_CHECK(false, "Max head dim exceeded"); + + #undef KERNEL_ARGS cuda_check(cudaPeekAtLastError()); } @@ -236,7 +243,8 @@ __global__ void gated_delta_net_fused_op_2_kernel int B, int S, int H, - int rows_per_block + int rows_per_block, + const float beta_scale ) { int t = threadIdx.x % H; @@ -250,7 +258,7 @@ __global__ void gated_delta_net_fused_op_2_kernel out_beta += row * H + t; out_g += row * H + t; - float beta = _sigmoid_fast_exp(*in_b); + float beta = _sigmoid_fast_exp(*in_b) * beta_scale; float dt_bias = as_float(*in_dt_bias); float g = -softplus(*in_a + dt_bias) * __expf(as_float(*in_a_log)); @@ -270,7 +278,8 @@ void gated_delta_net_fused_op_2 const at::Tensor& dt_bias, // [H] bfloat16 const at::Tensor& a_log, // [H] float at::Tensor& beta, // out [B,S,H] bfloat16 - at::Tensor& g // out [B,S,H] float + at::Tensor& g, // out [B,S,H] float + const float beta_scale ) { const at::cuda::OptionalCUDAGuard device_guard(b.device()); @@ -309,7 +318,8 @@ void gated_delta_net_fused_op_2 B, \ S, \ H, \ - rows_per_block + rows_per_block, \ + beta_scale if (a_log_fp32) gated_delta_net_fused_op_2_kernel<<>>(ARGS(float)); @@ -323,7 +333,8 @@ void gated_delta_net_fused_op_2 } -__global__ __launch_bounds__(R_THREADS * SUBK) +template +__global__ __launch_bounds__(MAX_HEAD_DIM * SUBK) void cuda_recurrent_gated_delta_rule_kernel ( // k_dim = num_k_heads * k_head_dim @@ -435,10 +446,10 @@ void cuda_recurrent_gated_delta_rule_kernel float* rs_rd = gl_rs + t + bt * bts * v_head_dim; // TODO: Could use tensor cores - for (int i = 0; i < k_head_dim / 16 / SUBK; ++i) + for (int i = 0; i < k_head_dim / 8 / SUBK; ++i) { #pragma unroll - for (int j = 0; j < 16; ++j, rs_rd += v_head_dim, sh_k_rd++) + for (int j = 0; j < 8; ++j, rs_rd += v_head_dim, sh_k_rd++) sum = sum + *sh_k_rd * *rs_rd; } atomicAdd(sh_dot1 + t, sum); @@ -462,10 +473,10 @@ void cuda_recurrent_gated_delta_rule_kernel float* rs_rw = gl_rs + t + bt * bts * v_head_dim; // TODO: Could use tensor cores - for (int i = 0; i < k_head_dim / 16 / SUBK; ++i) + for (int i = 0; i < k_head_dim / 8 / SUBK; ++i) { #pragma unroll - for (int j = 0; j < 16; ++j, rs_rw += v_head_dim, sh_k_rd++, sh_q_rd++) + for (int j = 0; j < 8; ++j, rs_rw += v_head_dim, sh_k_rd++, sh_q_rd++) { // State update step, k x v float state = *rs_rw; @@ -518,11 +529,6 @@ void cuda_recurrent_gated_delta_rule int seqlen = mixed_qkv.size(1); int qkv_dim = mixed_qkv.size(2); - TORCH_CHECK(num_k_heads <= MAX_K_HEADS, "num_k_heads > MAX_K_HEADS"); - TORCH_CHECK(num_v_heads <= MAX_V_HEADS, "num_v_heads > MAX_V_HEADS"); - TORCH_CHECK(k_head_dim <= MAX_HEAD_DIM, "k_head_dim > MAX_HEAD_DIM"); - TORCH_CHECK(v_head_dim <= MAX_HEAD_DIM, "v_head_dim > MAX_HEAD_DIM"); - TORCH_CHECK_DTYPE(mixed_qkv, kBFloat16); TORCH_CHECK_DTYPE(g, kFloat); TORCH_CHECK_DTYPE(beta, kBFloat16); @@ -534,19 +540,25 @@ void cuda_recurrent_gated_delta_rule float scale = 1.0f / sqrtf(k_head_dim); - cuda_recurrent_gated_delta_rule_kernel<<>> - ( - (const bfloat16*) mixed_qkv.data_ptr(), - (const float*) g.data_ptr(), - (const bfloat16*) beta.data_ptr(), - (float*) recurrent_state.data_ptr(), - (bfloat16*) core_attn_out.data_ptr(), - bsz, - seqlen, - num_k_heads, - num_v_heads, - k_head_dim, - v_head_dim, + #define KERNEL_ARGS \ + (const bfloat16*) mixed_qkv.data_ptr(), \ + (const float*) g.data_ptr(), \ + (const bfloat16*) beta.data_ptr(), \ + (float*) recurrent_state.data_ptr(), \ + (bfloat16*) core_attn_out.data_ptr(), \ + bsz, \ + seqlen, \ + num_k_heads, \ + num_v_heads, \ + k_head_dim, \ + v_head_dim, \ scale - ); + + if (threads.x <= 128) + cuda_recurrent_gated_delta_rule_kernel<128><<>>(KERNEL_ARGS); + else if (threads.x <= 256) + cuda_recurrent_gated_delta_rule_kernel<256><<>>(KERNEL_ARGS); + else TORCH_CHECK(false, "Max head dim exceeded"); + + #undef KERNEL_ARGS } diff --git a/exllamav3/exllamav3_ext/libtorch/gated_delta_net.cpp b/exllamav3/exllamav3_ext/libtorch/gated_delta_net.cpp index 5090546..bd1ecc3 100644 --- a/exllamav3/exllamav3_ext/libtorch/gated_delta_net.cpp +++ b/exllamav3/exllamav3_ext/libtorch/gated_delta_net.cpp @@ -28,7 +28,8 @@ at::Tensor BC_GatedDeltaNet::run_bsz1_a num_k_heads, num_v_heads, k_head_dim, - v_head_dim + v_head_dim, + beta_scale ); return mixed_qkv; diff --git a/exllamav3/exllamav3_ext/libtorch/gated_delta_net.h b/exllamav3/exllamav3_ext/libtorch/gated_delta_net.h index ef34dc7..b0ebc9a 100644 --- a/exllamav3/exllamav3_ext/libtorch/gated_delta_net.h +++ b/exllamav3/exllamav3_ext/libtorch/gated_delta_net.h @@ -31,6 +31,7 @@ struct BC_GatedDeltaNet c10::optional conv1d_bias; std::shared_ptr norm; std::shared_ptr o_proj; + const float beta_scale; BC_GatedDeltaNet ( @@ -55,7 +56,8 @@ struct BC_GatedDeltaNet at::Tensor _conv1d_weight, c10::optional _conv1d_bias, std::shared_ptr _norm, - std::shared_ptr _o_proj + std::shared_ptr _o_proj, + const float _beta_scale ) : mixed_qkv (std::move(_mixed_qkv)), z (std::move(_z)), @@ -78,7 +80,8 @@ struct BC_GatedDeltaNet conv1d_weight (std::move(_conv1d_weight)), conv1d_bias (std::move(_conv1d_bias)), norm (_norm), - o_proj (_o_proj) + o_proj (_o_proj), + beta_scale (_beta_scale) {} at::Tensor run_bsz1_a diff --git a/exllamav3/exllamav3_ext/libtorch/gated_delta_net_bc.h b/exllamav3/exllamav3_ext/libtorch/gated_delta_net_bc.h index 20daa77..9681de8 100644 --- a/exllamav3/exllamav3_ext/libtorch/gated_delta_net_bc.h +++ b/exllamav3/exllamav3_ext/libtorch/gated_delta_net_bc.h @@ -22,7 +22,8 @@ py::class_>(m, "BC_GatedDelt at::Tensor, c10::optional, std::shared_ptr, - std::shared_ptr + std::shared_ptr, + float >(), py::arg("mixed_qkv"), py::arg("z"), @@ -45,7 +46,8 @@ py::class_>(m, "BC_GatedDelt py::arg("conv1d_weight"), py::arg("conv1d_bias"), py::arg("norm"), - py::arg("o_proj") + py::arg("o_proj"), + py::arg("beta_scale") ) .def("run_bsz1_a", &BC_GatedDeltaNet::run_bsz1_a) .def("run_bsz1_b", &BC_GatedDeltaNet::run_bsz1_b); diff --git a/exllamav3/modules/gated_delta_net.py b/exllamav3/modules/gated_delta_net.py index 6384db1..148a91e 100644 --- a/exllamav3/modules/gated_delta_net.py +++ b/exllamav3/modules/gated_delta_net.py @@ -280,6 +280,7 @@ class GatedDeltaNet(Module): num_v_heads: int, rms_norm_eps: float, conv_kernel_size: int, + beta_scale: float = 1.0, key_a_log: str | None = None, key_dt_bias: str | None = None, key_conv1d: str | None = None, @@ -308,6 +309,7 @@ class GatedDeltaNet(Module): self.conv_kernel_size = conv_kernel_size self.k_dim = self.k_head_dim * self.num_k_heads self.v_dim = self.v_head_dim * self.num_v_heads + self.beta_scale = beta_scale self.out_dtype = out_dtype @@ -437,7 +439,8 @@ class GatedDeltaNet(Module): self.conv1d_weight, self.conv1d_bias, self.norm.bc, - self.o_proj.inner.bc + self.o_proj.inner.bc, + self.beta_scale ) @override @@ -575,7 +578,8 @@ class GatedDeltaNet(Module): self.num_k_heads, self.num_v_heads, self.k_head_dim, - self.v_head_dim + self.v_head_dim, + self.beta_scale ) else: # TODO: Bound class and/or graph for this part @@ -593,7 +597,8 @@ class GatedDeltaNet(Module): b, a, self.dt_bias, self.a_log, - beta, g + beta, g, + self.beta_scale ) # Convolution