GatedDeltaNet: Fused kernel for splitting inputs, casting, applying sigmoid etc.

turboderp
2025-09-21 05:02:00 +02:00
parent 1f6b3b5c0a
commit 8c71b0aa57
5 changed files with 281 additions and 32 deletions

View File

@@ -189,3 +189,4 @@ This project owes its existence to a wonderful community of FOSS developers and
- [QTIP](https://github.com/Cornell-RelaxML/qtip)
- [Transformers](https://github.com/huggingface/transformers)
- [Marlin](https://github.com/IST-DASLab/marlin)
- [Flash Linear Attention](https://github.com/fla-org/flash-linear-attention)

View File

@@ -13,6 +13,7 @@
#include "activation.cuh"
#include "softcap.cuh"
#include "routing.cuh"
#include "gdn.cuh"
#include "quant/quantize.cuh"
#include "quant/pack.cuh"

View File

@@ -0,0 +1,15 @@
void gated_delta_net_fused_op
(
const at::Tensor& mixed_qkvz,
const at::Tensor& mixed_ba,
const at::Tensor& dt_bias,
const at::Tensor& a_log,
at::Tensor& mixed_qkv,
at::Tensor& z,
at::Tensor& beta,
at::Tensor& g,
size_t num_k_heads,
size_t num_v_heads,
size_t k_head_dim,
size_t v_head_dim
);
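For reference, a minimal eager-mode sketch of what this op computes, usable as a test oracle for the kernel below. The name gated_delta_net_fused_op_ref is hypothetical; shapes and dtypes follow the comments in the .cu file, and results should match the kernel up to fast-math rounding:

    import torch
    import torch.nn.functional as F

    def gated_delta_net_fused_op_ref(mixed_qkvz, mixed_ba, dt_bias, a_log,
                                     num_k_heads, num_v_heads, k_head_dim, v_head_dim):
        B, S, _ = mixed_qkvz.shape
        Nk, Nv, Hk, Hv = num_k_heads, num_v_heads, k_head_dim, v_head_dim
        Ng = Nv // Nk
        # each k-head segment packs [q (Hk), k (Hk), v (Ng*Hv), z (Ng*Hv)]
        qkvz = mixed_qkvz.view(B, S, Nk, 2 * Hk + 2 * Ng * Hv)
        q, k, v, z = torch.split(qkvz, [Hk, Hk, Ng * Hv, Ng * Hv], dim = -1)
        mixed_qkv = torch.cat((
            q.reshape(B, S, Nk * Hk),
            k.reshape(B, S, Nk * Hk),
            v.reshape(B, S, Nv * Hv),   # global v-head order is kh * Ng + g
        ), dim = -1).transpose(1, 2).to(torch.bfloat16)   # [B, 2*Nk*Hk + Nv*Hv, S]
        z = z.reshape(B, S, Nv, Hv).to(torch.bfloat16)
        b, a = torch.split(mixed_ba.view(B, S, Nk, 2 * Ng), [Ng, Ng], dim = -1)
        beta = b.reshape(B, S, Nv).sigmoid().to(torch.bfloat16)
        g = -a_log.float().exp() * F.softplus(a.reshape(B, S, Nv) + dt_bias.float())
        return mixed_qkv, z, beta, g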

View File

@@ -0,0 +1,202 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "activation.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include "util.h"
#include "util.cuh"
#include "compat.cuh"
#include <cmath>
__device__ __forceinline__ float _sigmoid_fast_exp(float x)
{
    return 1.0f / (1.0f + __expf(-x));
}

__device__ __forceinline__ __nv_bfloat16 trunc_bf16(float x)
{
    return __float2bfloat16_rn(x);
}

__device__ __forceinline__ float untrunc_bf16(__nv_bfloat16 x)
{
    return __bfloat162float(x);
}

__device__ __forceinline__ float softplus(float x)  // matches F.softplus defaults: beta = 1.0, linear threshold = 20.0
{
    if (x > 20.0f) return x;
    return log1pf(__expf(x));
}
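The linear cutoff above mirrors PyTorch's F.softplus defaults (beta = 1.0, threshold = 20.0), which is the function the module code used before this change moved the computation into the kernel. A quick eager-mode sketch of that equivalence:

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-30.0, 30.0, steps = 13)
    ref = F.softplus(x)  # defaults: beta = 1.0, threshold = 20.0
    man = torch.where(x > 20.0, x, torch.log1p(torch.exp(x)))
    print(torch.allclose(ref, man))  # True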
__global__ void gated_delta_net_fused_op_kernel
(
    const float* __restrict__ in_qkvz,          // [B, S, Nk, Fseg], float32
    const float* __restrict__ in_ba,            // [B, S, Nk, 2*Ng], float32
    const __nv_bfloat16* __restrict__ dt_bias,  // [Nv], bfloat16
    const __nv_bfloat16* __restrict__ a_log,    // [Nv], bfloat16
    __nv_bfloat16* __restrict__ out_qkv,        // [B, 2*Nk*Hk + Nv*Hv, S], bfloat16
    __nv_bfloat16* __restrict__ out_z,          // [B, S, Nv, Hv], bfloat16
    __nv_bfloat16* __restrict__ out_beta,       // [B, S, Nv], bfloat16
    float* __restrict__ out_g,                  // [B, S, Nv], float32
    size_t B,
    size_t S,
    size_t Nk,
    size_t Ng,
    size_t Hk,
    size_t Hv
)
{
    const size_t Nv = Nk * Ng;
    const size_t Fseg = 2 * Hk + 2 * Ng * Hv;   // per k-head segment in mixed_qkvz
    const size_t Fba = 2 * Ng;                  // per k-head segment in mixed_ba
    const size_t Nlin = B * S * Nk;
    const size_t Fout = 2 * Nk * Hk + Nv * Hv;  // feature dim in mixed_qkv

    const size_t t = threadIdx.x;

    // One block per (b, s, kh) tuple, grid-stride loop over any remainder
    for (size_t linear = blockIdx.x; linear < Nlin; linear += (size_t) gridDim.x)
    {
        size_t kh = linear % Nk;
        size_t s = (linear / Nk) % S;
        size_t b = (linear / Nk) / S;

        // Base offsets into inputs for this (b, s, kh)
        const size_t base_qkvz = (((b * S) + s) * Nk + kh) * Fseg;
        const size_t base_ba = (((b * S) + s) * Nk + kh) * Fba;

        // q block: length Hk, source offset [0 .. Hk)
        // feature range in out_qkv: [kh*Hk, kh*Hk + Hk)
        const size_t q_feat0 = kh * Hk;
        if (t < Hk)
        {
            const float vq = in_qkvz[base_qkvz + t];
            const size_t f = q_feat0 + t;  // feature index in [0 .. Nk*Hk)
            const size_t out_off = ((b * Fout) + f) * S + s;
            out_qkv[out_off] = trunc_bf16(vq);
        }

        // k block: length Hk, source offset [Hk .. 2*Hk)
        // feature range in out_qkv: [Nk*Hk + kh*Hk, Nk*Hk + kh*Hk + Hk)
        const size_t k_in0 = Hk;
        const size_t k_feat0 = Nk * Hk + kh * Hk;
        if (t < Hk)
        {
            const float vk = in_qkvz[base_qkvz + k_in0 + t];
            const size_t f = k_feat0 + t;
            const size_t out_off = ((b * Fout) + f) * S + s;
            out_qkv[out_off] = trunc_bf16(vk);
        }

        // v and z blocks: each length Ng*Hv
        // v source offset: [2*Hk .. 2*Hk + Ng*Hv)
        // z source offset: [2*Hk + Ng*Hv .. 2*Hk + 2*Ng*Hv)
        const size_t v_in0 = 2 * Hk;
        const size_t z_in0 = 2 * Hk + Ng * Hv;
        const size_t v_feat_base = 2 * Nk * Hk;  // start of v block in feature dim
        if (t < Hv)
        {
            for (size_t g = 0; g < Ng; ++g)
            {
                const size_t vhead = kh * Ng + g;  // global v-head index in [0 .. Nv)

                // v -> out_qkv (feature block)
                const float vv = in_qkvz[base_qkvz + v_in0 + g * Hv + t];
                const size_t f = v_feat_base + vhead * Hv + t;
                const size_t out_v_off = ((b * Fout) + f) * S + s;
                out_qkv[out_v_off] = trunc_bf16(vv);

                // z -> out_z
                const float vz = in_qkvz[base_qkvz + z_in0 + g * Hv + t];
                const size_t out_z_off = ((((b * S) + s) * Nv) + vhead) * Hv + t;
                out_z[out_z_off] = trunc_bf16(vz);
            }
        }

        // b and a from mixed_ba (each Ng long) -> [B, S, Nv]
        if (t < Ng)
        {
            const size_t vhead = kh * Ng + t;
            const size_t out_ba_off = ((b * S) + s) * Nv + vhead;

            // beta = sigmoid(b).bfloat16()
            const float vb = in_ba[base_ba + t];
            out_beta[out_ba_off] = trunc_bf16(_sigmoid_fast_exp(vb));

            // g = -a_log.float().exp() * F.softplus(a + dt_bias.float())
            const float va = in_ba[base_ba + Ng + t];
            const float bi = untrunc_bf16(dt_bias[vhead]);
            const float al = untrunc_bf16(a_log[vhead]);
            out_g[out_ba_off] = -softplus(va + bi) * expf(al);
        }
    }
}
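Note that the out_qkv writes perform the [B, S, F] -> [B, F, S] transpose on the fly, so the causal convolution downstream sees a channels-first layout without a separate transpose/contiguous pass. The offset arithmetic is plain row-major indexing; a minimal sketch of the invariant it relies on:

    import torch

    B, F_, S = 2, 6, 3
    t = torch.arange(B * F_ * S).view(B, F_, S)   # contiguous [B, F, S]
    b, f, s = 1, 4, 2
    # same linear offset the kernel computes for out_qkv
    assert t.flatten()[(b * F_ + f) * S + s] == t[b, f, s]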
/*
Single kernel for splitting the projected qkvz and ba GDN inputs and producing the gate and beta
tensors. Also downcasts from float32 to bfloat16.
*/
void gated_delta_net_fused_op
(
    const at::Tensor& mixed_qkvz,  // [B, S, Nk*(2*Hk + 2*Ng*Hv)]
    const at::Tensor& mixed_ba,    // [B, S, Nk*(2*Ng)]
    const at::Tensor& dt_bias,     // [Nv]
    const at::Tensor& a_log,       // [Nv]
    at::Tensor& mixed_qkv,         // out [B, 2*Nk*Hk + Nv*Hv, S]
    at::Tensor& z,                 // out [B, S, Nv, Hv]
    at::Tensor& beta,              // out [B, S, Nv]
    at::Tensor& g,                 // out [B, S, Nv]
    size_t num_k_heads,
    size_t num_v_heads,
    size_t k_head_dim,
    size_t v_head_dim
)
{
    const at::cuda::OptionalCUDAGuard device_guard(mixed_qkvz.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const auto B = mixed_qkvz.size(0);
    const auto S = mixed_qkvz.size(1);
    const auto Nk = num_k_heads;
    const auto Hk = k_head_dim;
    const auto Hv = v_head_dim;
    const auto Nv = num_v_heads;

    TORCH_CHECK(Nk > 0 && Nv > 0 && Hk > 0 && Hv > 0, "invalid sizes");
    TORCH_CHECK(Nv % Nk == 0, "num_v_heads must be divisible by num_k_heads");
    const size_t Ng = Nv / Nk;
    const size_t Fseg = 2 * Hk + 2 * Ng * Hv;
    TORCH_CHECK(mixed_qkvz.size(2) == Nk * Fseg, "mixed_qkvz last dim should be Nk*(2*Hk + 2*Ng*Hv)");
    TORCH_CHECK(mixed_ba.size(2) == Nk * (2 * Ng), "mixed_ba last dim should be Nk*(2*Ng)");
    TORCH_CHECK(mixed_qkv.size(1) == 2 * Nk * Hk + Nv * Hv, "mixed_qkv must be [B, 2*Nk*Hk + Nv*Hv, S]");
    TORCH_CHECK(mixed_qkv.size(2) == S, "mixed_qkv must be [B, 2*Nk*Hk + Nv*Hv, S]");
    TORCH_CHECK_DTYPE(mixed_qkvz, kFloat);
    TORCH_CHECK_DTYPE(mixed_ba, kFloat);
    TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
    TORCH_CHECK_DTYPE(a_log, kBFloat16);
    TORCH_CHECK_DTYPE(mixed_qkv, kBFloat16);
    TORCH_CHECK_DTYPE(z, kBFloat16);
    TORCH_CHECK_DTYPE(beta, kBFloat16);
    TORCH_CHECK_DTYPE(g, kFloat);

    // One block per (b, s, kh) tuple; the block must cover the widest per-thread job,
    // which assumes Ng <= max(Hk, Hv) so the t < Ng branch is reached by enough threads
    const int blocks = B * S * Nk;
    const int threads = MAX(Hk, Hv);
    gated_delta_net_fused_op_kernel<<<blocks, threads, 0, stream>>>
    (
        (const float*) mixed_qkvz.data_ptr(),
        (const float*) mixed_ba.data_ptr(),
        (const __nv_bfloat16*) dt_bias.data_ptr(),
        (const __nv_bfloat16*) a_log.data_ptr(),
        (__nv_bfloat16*) mixed_qkv.data_ptr(),
        (__nv_bfloat16*) z.data_ptr(),
        (__nv_bfloat16*) beta.data_ptr(),
        (float*) g.data_ptr(),
        B, S, Nk, Ng, Hk, Hv
    );
    cuda_check(cudaPeekAtLastError());
}
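A hypothetical harness for checking the extension against the eager reference sketch above (ext stands for however the compiled extension module is imported; gated_delta_net_fused_op_ref is the sketch from earlier, and the dims are toy values). Tolerances are loose because the kernel uses __expf fast math and truncates to bfloat16:

    import torch

    B, S, Nk, Nv, Hk, Hv = 1, 4, 2, 4, 32, 64
    Ng = Nv // Nk
    dev = "cuda"
    mixed_qkvz = torch.randn(B, S, Nk * (2 * Hk + 2 * Ng * Hv), device = dev)
    mixed_ba = torch.randn(B, S, Nk * 2 * Ng, device = dev)
    dt_bias = torch.randn(Nv, device = dev, dtype = torch.bfloat16)
    a_log = torch.randn(Nv, device = dev, dtype = torch.bfloat16)
    mixed_qkv = torch.zeros(B, 2 * Nk * Hk + Nv * Hv, S, device = dev, dtype = torch.bfloat16)
    z = torch.zeros(B, S, Nv, Hv, device = dev, dtype = torch.bfloat16)
    beta = torch.zeros(B, S, Nv, device = dev, dtype = torch.bfloat16)
    g = torch.zeros(B, S, Nv, device = dev, dtype = torch.float)
    ext.gated_delta_net_fused_op(mixed_qkvz, mixed_ba, dt_bias, a_log,
                                 mixed_qkv, z, beta, g, Nk, Nv, Hk, Hv)
    ref_qkv, ref_z, ref_beta, ref_g = gated_delta_net_fused_op_ref(
        mixed_qkvz, mixed_ba, dt_bias, a_log, Nk, Nv, Hk, Hv)
    torch.testing.assert_close(mixed_qkv, ref_qkv, rtol = 1e-2, atol = 1e-2)
    torch.testing.assert_close(g, ref_g, rtol = 1e-3, atol = 1e-3)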

View File

@@ -1,7 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass
from typing_extensions import override
import torch
import torch.nn.functional as F
@@ -203,11 +201,14 @@ class GatedDeltaNet(Module):
self.num_v_groups = num_v_heads // num_k_heads
self.rms_norm_eps = rms_norm_eps
self.conv_kernel_size = conv_kernel_size
self.k_dim = self.k_head_dim * self.num_k_heads
self.v_dim = self.v_head_dim * self.num_v_heads
self.out_dtype = out_dtype
fdim_qkvz = 2 * self.num_k_heads * self.k_head_dim + 2 * self.num_v_heads * self.v_head_dim
fdim_ba = 2 * self.num_v_heads
self.fdim_qkv = 2 * self.num_k_heads * self.k_head_dim + self.num_v_heads * self.v_head_dim
if key_fused_qkvz:
self.qkvz_proj = Linear(config, f"{key}.{key_fused_qkvz}", hidden_size, fdim_qkvz, qmap = qmap + ".input", out_dtype = torch.float)
@@ -243,6 +244,7 @@ class GatedDeltaNet(Module):
"recurrent_cache": True
})
self.prealloc_split = None
# self.cache_layers = []
# self.tp_cache_lookup = {}
# self.multi_kv = None
@@ -270,6 +272,13 @@ class GatedDeltaNet(Module):
self.norm.load(device, **kwargs)
self.load_local(device, **kwargs)
# Preallocate (mixed_qkv, z, beta, g) tensors for bsz 1, seqlen 1
self.prealloc_split = (
    torch.zeros((1, self.fdim_qkv, 1), dtype = torch.bfloat16, device = device),
    torch.zeros((1, 1, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = device),
    torch.zeros((1, 1, self.num_v_heads), dtype = torch.bfloat16, device = device),
    torch.zeros((1, 1, self.num_v_heads), dtype = torch.float, device = device)
)
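Presumably the point of preallocating for bsz = 1, seqlen = 1 is that this is the token-by-token decode hot path: reusing one set of output buffers avoids allocator traffic on every step. A minimal sketch of what the tuple holds, with m standing for a loaded GatedDeltaNet instance (hypothetical name):

    mixed_qkv, z, beta, g = m.prealloc_split   # same storage on every decode step
    print(mixed_qkv.shape)                     # [1, 2*Nk*Hk + Nv*Hv, 1]
    print(z.shape, beta.shape, g.shape)        # [1, 1, Nv, Hv], [1, 1, Nv], [1, 1, Nv]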
@override
def unload(self):
@@ -283,30 +292,45 @@ class GatedDeltaNet(Module):
def split_fused_inputs(self, mixed_qkvz, mixed_ba):
    # mixed_qkvz and mixed_ba have the same (bsz, seqlen)
    # both are contiguous
    bsz, seqlen, _ = mixed_qkvz.shape
    mixed_qkvz = mixed_qkvz.view(
        bsz,
        seqlen,
        self.num_k_heads,
        2 * self.k_head_dim + 2 * self.v_head_dim * self.num_v_heads // self.num_k_heads,
    )
    mixed_ba = mixed_ba.view(
        bsz,
        seqlen,
        self.num_k_heads,
        2 * self.num_v_heads // self.num_k_heads
    )
    split_arg_list_qkvz = [
        self.k_head_dim,
        self.k_head_dim,
        (self.num_v_groups * self.v_head_dim),
        (self.num_v_groups * self.v_head_dim),
    ]
    split_arg_list_ba = [
        self.num_v_heads // self.num_k_heads,
        self.num_v_heads // self.num_k_heads
    ]
    q, k, v, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim = 3)
    b, a = torch.split(mixed_ba, split_arg_list_ba, dim = 3)
    q = q.reshape(bsz, seqlen, -1)
    k = k.reshape(bsz, seqlen, -1)
    v = v.reshape(bsz, seqlen, -1)
    z = z.reshape(bsz, seqlen, -1, self.v_head_dim)
    b = b.reshape(bsz, seqlen, self.num_v_heads)
    a = a.reshape(bsz, seqlen, self.num_v_heads)
    mixed_qkv = torch.cat((q, k, v), dim = -1)
    mixed_qkv = mixed_qkv.transpose(1, 2)
    return mixed_qkv, z, b, a
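To make the packing concrete, a toy-shape walkthrough of this eager fallback path, assuming hypothetical dims Nk = 2, Nv = 4 (so num_v_groups = 2), Hk = 8, Hv = 16, with m as a module instance:

    import torch

    mixed_qkvz = torch.randn(1, 3, 2 * (2 * 8 + 2 * 2 * 16))  # [bsz, seqlen, Nk*(2*Hk + 2*Ng*Hv)] = [1, 3, 160]
    mixed_ba = torch.randn(1, 3, 2 * (2 * 2))                 # [bsz, seqlen, Nk*2*Ng] = [1, 3, 8]
    mixed_qkv, z, b, a = m.split_fused_inputs(mixed_qkvz, mixed_ba)
    print(mixed_qkv.shape)   # [1, 2*2*8 + 4*16, 3] = [1, 96, 3]
    print(z.shape)           # [1, 3, 4, 16]
    print(b.shape, a.shape)  # [1, 3, 4], [1, 3, 4]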
@override
@@ -336,16 +360,29 @@ class GatedDeltaNet(Module):
save_state = False
# Projections
qkvz = self.qkvz_proj.forward(x, params)
ba = self.ba_proj.forward(x, params)
if bsz == 1 and seqlen == 1:
    mixed_qkv, z, beta, g = self.prealloc_split
else:
    mixed_qkv = torch.zeros((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
    z = torch.zeros((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
    beta = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
    g = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
ext.gated_delta_net_fused_op(
    qkvz, ba,
    self.dt_bias,
    self.a_log,
    mixed_qkv, z, beta, g,
    self.num_k_heads,
    self.num_v_heads,
    self.k_head_dim,
    self.v_head_dim
)
# Convolution
if conv_state is None:
    if save_state:
        conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
@@ -374,18 +411,11 @@ class GatedDeltaNet(Module):
mixed_qkv = mixed_qkv.transpose(1, 2)
# Gate
q, k, v = torch.split(mixed_qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)
q = q.view(bsz, seqlen, -1, self.k_head_dim)
k = k.view(bsz, seqlen, -1, self.k_head_dim)
v = v.view(bsz, seqlen, -1, self.v_head_dim)
# Grouped attn
if self.num_v_heads // self.num_k_heads > 1:
    q = q.repeat_interleave(self.num_v_groups, dim = 2)
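The interleaved repeat matters for correctness here: repeat_interleave(Ng, dim = 2) yields head order [0, 0, 1, 1, ...], which lines up query head kh with the v-heads at indices kh*Ng .. kh*Ng + Ng - 1, the same global v-head ordering the fused kernel writes. A small sketch with toy dims:

    import torch

    q = torch.randn(1, 3, 2, 8)          # [bsz, seqlen, Nk, Hk] with Nk = 2
    q = q.repeat_interleave(2, dim = 2)  # [1, 3, 4, 8], head order [0, 0, 1, 1]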
@@ -398,7 +428,7 @@ class GatedDeltaNet(Module):
g = g,
beta = beta,
initial_state = recurrent_state,
output_final_state = save_state,
use_qk_l2norm_in_kernel = True,
)
else: