mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-05-11 16:30:12 +00:00
GatedDeltaNet: Handle head sizes up to 256 that are multiples of 32 (previously multiples of 64), and support beta scale (linear_allow_neg_eigval)
This commit is contained in:
@@ -11,7 +11,8 @@ void gated_delta_net_fused_op
|
||||
size_t num_k_heads,
|
||||
size_t num_v_heads,
|
||||
size_t k_head_dim,
|
||||
size_t v_head_dim
|
||||
size_t v_head_dim,
|
||||
const float beta_scale
|
||||
);
|
||||
|
||||
void gated_delta_net_fused_op_2
|
||||
@@ -21,7 +22,8 @@ void gated_delta_net_fused_op_2
|
||||
const at::Tensor& dt_bias,
|
||||
const at::Tensor& a_log,
|
||||
at::Tensor& beta,
|
||||
at::Tensor& g
|
||||
at::Tensor& g,
|
||||
const float beta_scale
|
||||
);
|
||||
|
||||
void cuda_recurrent_gated_delta_rule
|
||||
|
||||
@@ -9,11 +9,9 @@
|
||||
#include <cmath>
|
||||
|
||||
using bfloat16 = __nv_bfloat16;
|
||||
#define MAX_HEAD_DIM 128
|
||||
#define MAX_K_HEADS 32
|
||||
#define MAX_V_HEADS 64
|
||||
|
||||
#define R_THREADS MAX_HEAD_DIM
|
||||
#define SUBK 4
|
||||
|
||||
#define FUSED_OP_2_THREADS 512
|
||||
@@ -49,6 +47,7 @@ __device__ __forceinline__ float softplus(float x) // beta=1.0, linear threshol
|
||||
return log1pf(__expf(x));
|
||||
}
|
||||
|
||||
template<int MAX_HEAD_DIM>
|
||||
__global__ __launch_bounds__(MAX_HEAD_DIM)
|
||||
void gated_delta_net_fused_op_kernel
|
||||
(
|
||||
@@ -65,7 +64,8 @@ void gated_delta_net_fused_op_kernel
|
||||
const size_t Nk,
|
||||
const size_t Ng,
|
||||
const size_t Hk,
|
||||
const size_t Hv
|
||||
const size_t Hv,
|
||||
const float beta_scale
|
||||
)
|
||||
{
|
||||
const size_t Nv = Nk * Ng;
|
||||
@@ -143,7 +143,7 @@ void gated_delta_net_fused_op_kernel
|
||||
|
||||
// beta = sigmoid(b).bfloat16()
|
||||
float b = in_ba[base_ba + t];
|
||||
out_beta[out_va_off] = trunc_bf16(_sigmoid_fast_exp(b));
|
||||
out_beta[out_va_off] = trunc_bf16(_sigmoid_fast_exp(b) * beta_scale);
|
||||
|
||||
// g = -self.a_log.float().exp() * F.softplus(a + self.dt_bias.float())
|
||||
float g = in_ba[base_ba + Ng + t];
|
||||
@@ -172,7 +172,8 @@ void gated_delta_net_fused_op
|
||||
size_t num_k_heads,
|
||||
size_t num_v_heads,
|
||||
size_t k_head_dim,
|
||||
size_t v_head_dim
|
||||
size_t v_head_dim,
|
||||
const float beta_scale
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(mixed_qkvz.device());
|
||||
@@ -206,20 +207,26 @@ void gated_delta_net_fused_op
|
||||
|
||||
const int blocks = B * S * Nk;
|
||||
const int threads = MAX(Hk, Hv);
|
||||
TORCH_CHECK(threads <= MAX_HEAD_DIM, "Max head dim exceeded");
|
||||
|
||||
gated_delta_net_fused_op_kernel<<<blocks, threads, 0, stream>>>
|
||||
(
|
||||
(const float*) mixed_qkvz.data_ptr(),
|
||||
(const float*) mixed_ba.data_ptr(),
|
||||
(const bfloat16*) dt_bias.data_ptr(),
|
||||
(const bfloat16*) a_log.data_ptr(),
|
||||
(bfloat16*) mixed_qkv.data_ptr(),
|
||||
(bfloat16*) z.data_ptr(),
|
||||
(bfloat16*) beta.data_ptr(),
|
||||
(float*) g.data_ptr(),
|
||||
B, S, Nk, Ng, Hk, Hv
|
||||
);
|
||||
#define KERNEL_ARGS \
|
||||
(const float*) mixed_qkvz.data_ptr(), \
|
||||
(const float*) mixed_ba.data_ptr(), \
|
||||
(const bfloat16*) dt_bias.data_ptr(), \
|
||||
(const bfloat16*) a_log.data_ptr(), \
|
||||
(bfloat16*) mixed_qkv.data_ptr(), \
|
||||
(bfloat16*) z.data_ptr(), \
|
||||
(bfloat16*) beta.data_ptr(), \
|
||||
(float*) g.data_ptr(), \
|
||||
B, S, Nk, Ng, Hk, Hv, \
|
||||
beta_scale
|
||||
|
||||
if (threads <= 128)
|
||||
gated_delta_net_fused_op_kernel<128><<<blocks, threads, 0, stream>>>(KERNEL_ARGS);
|
||||
else if (threads <= 256)
|
||||
gated_delta_net_fused_op_kernel<256><<<blocks, threads, 0, stream>>>(KERNEL_ARGS);
|
||||
else TORCH_CHECK(false, "Max head dim exceeded");
|
||||
|
||||
#undef KERNEL_ARGS
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
@@ -236,7 +243,8 @@ __global__ void gated_delta_net_fused_op_2_kernel
|
||||
int B,
|
||||
int S,
|
||||
int H,
|
||||
int rows_per_block
|
||||
int rows_per_block,
|
||||
const float beta_scale
|
||||
)
|
||||
{
|
||||
int t = threadIdx.x % H;
|
||||
@@ -250,7 +258,7 @@ __global__ void gated_delta_net_fused_op_2_kernel
|
||||
out_beta += row * H + t;
|
||||
out_g += row * H + t;
|
||||
|
||||
float beta = _sigmoid_fast_exp(*in_b);
|
||||
float beta = _sigmoid_fast_exp(*in_b) * beta_scale;
|
||||
float dt_bias = as_float(*in_dt_bias);
|
||||
float g = -softplus(*in_a + dt_bias) * __expf(as_float(*in_a_log));
|
||||
|
||||
@@ -270,7 +278,8 @@ void gated_delta_net_fused_op_2
|
||||
const at::Tensor& dt_bias, // [H] bfloat16
|
||||
const at::Tensor& a_log, // [H] float
|
||||
at::Tensor& beta, // out [B,S,H] bfloat16
|
||||
at::Tensor& g // out [B,S,H] float
|
||||
at::Tensor& g, // out [B,S,H] float
|
||||
const float beta_scale
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(b.device());
|
||||
@@ -309,7 +318,8 @@ void gated_delta_net_fused_op_2
|
||||
B, \
|
||||
S, \
|
||||
H, \
|
||||
rows_per_block
|
||||
rows_per_block, \
|
||||
beta_scale
|
||||
|
||||
if (a_log_fp32)
|
||||
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>(ARGS(float));
|
||||
@@ -323,7 +333,8 @@ void gated_delta_net_fused_op_2
|
||||
}
|
||||
|
||||
|
||||
__global__ __launch_bounds__(R_THREADS * SUBK)
|
||||
template <int MAX_HEAD_DIM>
|
||||
__global__ __launch_bounds__(MAX_HEAD_DIM * SUBK)
|
||||
void cuda_recurrent_gated_delta_rule_kernel
|
||||
(
|
||||
// k_dim = num_k_heads * k_head_dim
|
||||
@@ -435,10 +446,10 @@ void cuda_recurrent_gated_delta_rule_kernel
|
||||
float* rs_rd = gl_rs + t + bt * bts * v_head_dim;
|
||||
|
||||
// TODO: Could use tensor cores
|
||||
for (int i = 0; i < k_head_dim / 16 / SUBK; ++i)
|
||||
for (int i = 0; i < k_head_dim / 8 / SUBK; ++i)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 16; ++j, rs_rd += v_head_dim, sh_k_rd++)
|
||||
for (int j = 0; j < 8; ++j, rs_rd += v_head_dim, sh_k_rd++)
|
||||
sum = sum + *sh_k_rd * *rs_rd;
|
||||
}
|
||||
atomicAdd(sh_dot1 + t, sum);
|
||||
@@ -462,10 +473,10 @@ void cuda_recurrent_gated_delta_rule_kernel
|
||||
float* rs_rw = gl_rs + t + bt * bts * v_head_dim;
|
||||
|
||||
// TODO: Could use tensor cores
|
||||
for (int i = 0; i < k_head_dim / 16 / SUBK; ++i)
|
||||
for (int i = 0; i < k_head_dim / 8 / SUBK; ++i)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 16; ++j, rs_rw += v_head_dim, sh_k_rd++, sh_q_rd++)
|
||||
for (int j = 0; j < 8; ++j, rs_rw += v_head_dim, sh_k_rd++, sh_q_rd++)
|
||||
{
|
||||
// State update step, k x v
|
||||
float state = *rs_rw;
|
||||
@@ -518,11 +529,6 @@ void cuda_recurrent_gated_delta_rule
|
||||
int seqlen = mixed_qkv.size(1);
|
||||
int qkv_dim = mixed_qkv.size(2);
|
||||
|
||||
TORCH_CHECK(num_k_heads <= MAX_K_HEADS, "num_k_heads > MAX_K_HEADS");
|
||||
TORCH_CHECK(num_v_heads <= MAX_V_HEADS, "num_v_heads > MAX_V_HEADS");
|
||||
TORCH_CHECK(k_head_dim <= MAX_HEAD_DIM, "k_head_dim > MAX_HEAD_DIM");
|
||||
TORCH_CHECK(v_head_dim <= MAX_HEAD_DIM, "v_head_dim > MAX_HEAD_DIM");
|
||||
|
||||
TORCH_CHECK_DTYPE(mixed_qkv, kBFloat16);
|
||||
TORCH_CHECK_DTYPE(g, kFloat);
|
||||
TORCH_CHECK_DTYPE(beta, kBFloat16);
|
||||
@@ -534,19 +540,25 @@ void cuda_recurrent_gated_delta_rule
|
||||
|
||||
float scale = 1.0f / sqrtf(k_head_dim);
|
||||
|
||||
cuda_recurrent_gated_delta_rule_kernel<<<blocks, threads, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) mixed_qkv.data_ptr(),
|
||||
(const float*) g.data_ptr(),
|
||||
(const bfloat16*) beta.data_ptr(),
|
||||
(float*) recurrent_state.data_ptr(),
|
||||
(bfloat16*) core_attn_out.data_ptr(),
|
||||
bsz,
|
||||
seqlen,
|
||||
num_k_heads,
|
||||
num_v_heads,
|
||||
k_head_dim,
|
||||
v_head_dim,
|
||||
#define KERNEL_ARGS \
|
||||
(const bfloat16*) mixed_qkv.data_ptr(), \
|
||||
(const float*) g.data_ptr(), \
|
||||
(const bfloat16*) beta.data_ptr(), \
|
||||
(float*) recurrent_state.data_ptr(), \
|
||||
(bfloat16*) core_attn_out.data_ptr(), \
|
||||
bsz, \
|
||||
seqlen, \
|
||||
num_k_heads, \
|
||||
num_v_heads, \
|
||||
k_head_dim, \
|
||||
v_head_dim, \
|
||||
scale
|
||||
);
|
||||
|
||||
if (threads.x <= 128)
|
||||
cuda_recurrent_gated_delta_rule_kernel<128><<<blocks, threads, 0, stream>>>(KERNEL_ARGS);
|
||||
else if (threads.x <= 256)
|
||||
cuda_recurrent_gated_delta_rule_kernel<256><<<blocks, threads, 0, stream>>>(KERNEL_ARGS);
|
||||
else TORCH_CHECK(false, "Max head dim exceeded");
|
||||
|
||||
#undef KERNEL_ARGS
|
||||
}
|
||||
|
||||
@@ -28,7 +28,8 @@ at::Tensor BC_GatedDeltaNet::run_bsz1_a
|
||||
num_k_heads,
|
||||
num_v_heads,
|
||||
k_head_dim,
|
||||
v_head_dim
|
||||
v_head_dim,
|
||||
beta_scale
|
||||
);
|
||||
|
||||
return mixed_qkv;
|
||||
|
||||
@@ -31,6 +31,7 @@ struct BC_GatedDeltaNet
|
||||
c10::optional<at::Tensor> conv1d_bias;
|
||||
std::shared_ptr<BC_GatedRMSNorm> norm;
|
||||
std::shared_ptr<BC_LinearEXL3> o_proj;
|
||||
const float beta_scale;
|
||||
|
||||
BC_GatedDeltaNet
|
||||
(
|
||||
@@ -55,7 +56,8 @@ struct BC_GatedDeltaNet
|
||||
at::Tensor _conv1d_weight,
|
||||
c10::optional<at::Tensor> _conv1d_bias,
|
||||
std::shared_ptr<BC_GatedRMSNorm> _norm,
|
||||
std::shared_ptr<BC_LinearEXL3> _o_proj
|
||||
std::shared_ptr<BC_LinearEXL3> _o_proj,
|
||||
const float _beta_scale
|
||||
) :
|
||||
mixed_qkv (std::move(_mixed_qkv)),
|
||||
z (std::move(_z)),
|
||||
@@ -78,7 +80,8 @@ struct BC_GatedDeltaNet
|
||||
conv1d_weight (std::move(_conv1d_weight)),
|
||||
conv1d_bias (std::move(_conv1d_bias)),
|
||||
norm (_norm),
|
||||
o_proj (_o_proj)
|
||||
o_proj (_o_proj),
|
||||
beta_scale (_beta_scale)
|
||||
{}
|
||||
|
||||
at::Tensor run_bsz1_a
|
||||
|
||||
@@ -22,7 +22,8 @@ py::class_<BC_GatedDeltaNet, std::shared_ptr<BC_GatedDeltaNet>>(m, "BC_GatedDelt
|
||||
at::Tensor,
|
||||
c10::optional<at::Tensor>,
|
||||
std::shared_ptr<BC_GatedRMSNorm>,
|
||||
std::shared_ptr<BC_LinearEXL3>
|
||||
std::shared_ptr<BC_LinearEXL3>,
|
||||
float
|
||||
>(),
|
||||
py::arg("mixed_qkv"),
|
||||
py::arg("z"),
|
||||
@@ -45,7 +46,8 @@ py::class_<BC_GatedDeltaNet, std::shared_ptr<BC_GatedDeltaNet>>(m, "BC_GatedDelt
|
||||
py::arg("conv1d_weight"),
|
||||
py::arg("conv1d_bias"),
|
||||
py::arg("norm"),
|
||||
py::arg("o_proj")
|
||||
py::arg("o_proj"),
|
||||
py::arg("beta_scale")
|
||||
)
|
||||
.def("run_bsz1_a", &BC_GatedDeltaNet::run_bsz1_a)
|
||||
.def("run_bsz1_b", &BC_GatedDeltaNet::run_bsz1_b);
|
||||
|
||||
@@ -280,6 +280,7 @@ class GatedDeltaNet(Module):
|
||||
num_v_heads: int,
|
||||
rms_norm_eps: float,
|
||||
conv_kernel_size: int,
|
||||
beta_scale: float = 1.0,
|
||||
key_a_log: str | None = None,
|
||||
key_dt_bias: str | None = None,
|
||||
key_conv1d: str | None = None,
|
||||
@@ -308,6 +309,7 @@ class GatedDeltaNet(Module):
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.k_dim = self.k_head_dim * self.num_k_heads
|
||||
self.v_dim = self.v_head_dim * self.num_v_heads
|
||||
self.beta_scale = beta_scale
|
||||
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
@@ -437,7 +439,8 @@ class GatedDeltaNet(Module):
|
||||
self.conv1d_weight,
|
||||
self.conv1d_bias,
|
||||
self.norm.bc,
|
||||
self.o_proj.inner.bc
|
||||
self.o_proj.inner.bc,
|
||||
self.beta_scale
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -575,7 +578,8 @@ class GatedDeltaNet(Module):
|
||||
self.num_k_heads,
|
||||
self.num_v_heads,
|
||||
self.k_head_dim,
|
||||
self.v_head_dim
|
||||
self.v_head_dim,
|
||||
self.beta_scale
|
||||
)
|
||||
else:
|
||||
# TODO: Bound class and/or graph for this part
|
||||
@@ -593,7 +597,8 @@ class GatedDeltaNet(Module):
|
||||
b, a,
|
||||
self.dt_bias,
|
||||
self.a_log,
|
||||
beta, g
|
||||
beta, g,
|
||||
self.beta_scale
|
||||
)
|
||||
|
||||
# Convolution
|
||||
|
||||
Reference in New Issue
Block a user