mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
GatedDeltaNet: Fused kernel for splitting inputs, casting, applying sigmoid etc.
This commit is contained in:
@@ -189,3 +189,4 @@ This project owes its existence to a wonderful community of FOSS developers and
|
||||
- [QTIP](https://github.com/Cornell-RelaxML/qtip)
|
||||
- [Transformers](https://github.com/huggingface/transformers)
|
||||
- [Marlin](https://github.com/IST-DASLab/marlin)
|
||||
- [Flash Linear Attention](https://github.com/fla-org/flash-linear-attention)
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "activation.cuh"
|
||||
#include "softcap.cuh"
|
||||
#include "routing.cuh"
|
||||
#include "gdn.cuh"
|
||||
|
||||
#include "quant/quantize.cuh"
|
||||
#include "quant/pack.cuh"
|
||||
|
||||
15
exllamav3/exllamav3_ext/gdn.cuh
Normal file
15
exllamav3/exllamav3_ext/gdn.cuh
Normal file
@@ -0,0 +1,15 @@
|
||||
// Fused GatedDeltaNet input op: splits the projected qkvz/ba activations into
// the conv input, z, beta and g tensors, applying sigmoid/softplus gating and
// downcasting float32 -> bfloat16 in a single kernel. See gdn.cu for details.
void gated_delta_net_fused_op
(
    const at::Tensor& mixed_qkvz,   // in:  [B, S, Nk*(2*Hk + 2*Ng*Hv)], float32
    const at::Tensor& mixed_ba,     // in:  [B, S, Nk*(2*Ng)], float32
    const at::Tensor& dt_bias,      // in:  [Nv], bfloat16
    const at::Tensor& a_log,        // in:  [Nv], bfloat16
    at::Tensor& mixed_qkv,          // out: [B, 2*Nk*Hk + Nv*Hv, S], bfloat16
    at::Tensor& z,                  // out: [B, S, Nv, Hv], bfloat16
    at::Tensor& beta,               // out: [B, S, Nv], bfloat16
    at::Tensor& g,                  // out: [B, S, Nv], float32
    size_t num_k_heads,             // Nk
    size_t num_v_heads,             // Nv, must be a multiple of Nk
    size_t k_head_dim,              // Hk
    size_t v_head_dim               // Hv
);
|
||||
202
exllamav3/exllamav3_ext/gdn.cu
Normal file
202
exllamav3/exllamav3_ext/gdn.cu
Normal file
@@ -0,0 +1,202 @@
|
||||
#include <cuda_fp16.h>
#include <cuda_fp16.hpp>
#include <cuda_bf16.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "activation.cuh"
#include "util.h"
#include "util.cuh"
#include "compat.cuh"
||||
|
||||
// Logistic sigmoid, 1 / (1 + e^-x), using the fast exponential intrinsic.
__device__ __forceinline__ float _sigmoid_fast_exp(float x)
{
    const float e = __expf(-x);
    return 1.0f / (1.0f + e);
}
|
||||
|
||||
// Downcast float32 -> bfloat16 (round-to-nearest-even).
__device__ __forceinline__ __nv_bfloat16 trunc_bf16(float x)
{
    const __nv_bfloat16 r = __float2bfloat16_rn(x);
    return r;
}
|
||||
|
||||
// Upcast bfloat16 -> float32 (exact).
__device__ __forceinline__ float untrunc_bf16(__nv_bfloat16 x)
{
    const float r = __bfloat162float(x);
    return r;
}
|
||||
|
||||
// Softplus log(1 + e^x), matching F.softplus with beta = 1.0 and linear
// threshold = 20.0 (returns x unchanged above the threshold).
__device__ __forceinline__ float softplus(float x)
{
    return (x > 20.0f) ? x : log1pf(__expf(x));
}
|
||||
|
||||
// One block processes one (batch, seq, k-head) triple, walked via a
// grid-stride loop so any grid size is correct. Thread t handles element t of
// each Hk- (q/k), Hv- (v/z) and Ng-sized (b/a) segment, so the launch must
// supply at least max(Hk, Hv, Ng) threads per block.
__global__ void gated_delta_net_fused_op_kernel
(
    const float* __restrict__ in_qkvz,          // [B, S, Nk, Fseg], float32
    const float* __restrict__ in_ba,            // [B, S, Nk, 2*Ng], float32
    const __nv_bfloat16* __restrict__ dt_bias,  // [Nv], bfloat16
    const __nv_bfloat16* __restrict__ a_log,    // [Nv], bfloat16
    __nv_bfloat16* __restrict__ out_qkv,        // [B, 2*Nk*Hk + Nv*Hv, S], bfloat16
    __nv_bfloat16* __restrict__ out_z,          // [B, S, Nv, Hv], bfloat16
    __nv_bfloat16* __restrict__ out_beta,       // [B, S, Nv], bfloat16
    float* __restrict__ out_g,                  // [B, S, Nv], float32
    size_t B,
    size_t S,
    size_t Nk,
    size_t Ng,
    size_t Hk,
    size_t Hv
){
    const size_t Nv = Nk * Ng;
    const size_t Fseg = 2 * Hk + 2 * Ng * Hv;   // per-khead segment in mixed_qkvz
    const size_t Fba = 2 * Ng;                  // per-khead segment in mixed_ba
    const size_t Nlin = B * S * Nk;
    const size_t Fout = 2 * Nk * Hk + Nv * Hv;  // feature dim in mixed_qkv

    const size_t t = threadIdx.x;

    for (size_t linear = blockIdx.x; linear < Nlin; linear += (size_t) gridDim.x)
    {
        const size_t kh = linear % Nk;
        const size_t s = (linear / Nk) % S;
        const size_t bi = (linear / Nk) / S;    // batch index

        // Base offsets into inputs for this (bi, s, kh)
        const size_t base_qkvz = (((bi * S) + s) * Nk + kh) * Fseg;
        const size_t base_ba = (((bi * S) + s) * Nk + kh) * Fba;

        // q block: length Hk, source offset [0 .. Hk),
        //   feature range in out_qkv: [kh*Hk, kh*Hk + Hk)
        // k block: length Hk, source offset [Hk .. 2*Hk),
        //   feature range in out_qkv: [Nk*Hk + kh*Hk, Nk*Hk + kh*Hk + Hk)
        if (t < Hk)
        {
            const float vq = in_qkvz[base_qkvz + t];
            const size_t fq = kh * Hk + t;      // feature index in [0 .. Nk*Hk)
            out_qkv[((bi * Fout) + fq) * S + s] = trunc_bf16(vq);

            const float vk = in_qkvz[base_qkvz + Hk + t];
            const size_t fk = Nk * Hk + kh * Hk + t;
            out_qkv[((bi * Fout) + fk) * S + s] = trunc_bf16(vk);
        }

        // v and z blocks: each length Ng*Hv
        // v source offset: [2*Hk .. 2*Hk + Ng*Hv)
        // z source offset: [2*Hk + Ng*Hv .. 2*Hk + 2*Ng*Hv)
        const size_t v_in0 = 2 * Hk;
        const size_t z_in0 = 2 * Hk + Ng * Hv;
        const size_t v_feat_base = 2 * Nk * Hk;  // start of v block in feature dim

        if (t < Hv)
        {
            for (size_t grp = 0; grp < Ng; ++grp)
            {
                const size_t vhead = kh * Ng + grp;  // global v-head index in [0 .. Nv)

                // v -> out_qkv (feature block)
                const float vv = in_qkvz[base_qkvz + v_in0 + grp * Hv + t];
                const size_t f = v_feat_base + vhead * Hv + t;
                out_qkv[((bi * Fout) + f) * S + s] = trunc_bf16(vv);

                // z -> out_z
                const float vz = in_qkvz[base_qkvz + z_in0 + grp * Hv + t];
                out_z[((((bi * S) + s) * Nv) + vhead) * Hv + t] = trunc_bf16(vz);
            }
        }

        // b and a from mixed_ba (each Ng long) -> [B, S, Nv]
        if (t < Ng)
        {
            const size_t vhead = kh * Ng + t;
            const size_t out_va_off = ((bi * S) + s) * Nv + vhead;

            // beta = sigmoid(b), downcast to bfloat16
            const float bv = in_ba[base_ba + t];
            out_beta[out_va_off] = trunc_bf16(_sigmoid_fast_exp(bv));

            // g = -a_log.float().exp() * F.softplus(a + dt_bias.float()),
            // kept in float32. dt_bias/a_log are indexed by v-head.
            const float av = in_ba[base_ba + Ng + t];
            const float bias = untrunc_bf16(dt_bias[vhead]);
            const float al = untrunc_bf16(a_log[vhead]);
            out_g[out_va_off] = -softplus(av + bias) * expf(al);
        }
    }
}
|
||||
|
||||
/*
Single kernel for splitting projected qkvz + ba GDN inputs and producing gate + beta tensors.
Also downcasts from float32 to bfloat16.

All input/output tensors must be contiguous and on the same CUDA device.
Launches one block per (batch, seq, k-head); the kernel's grid-stride loop
makes a clamped grid safe.
*/

void gated_delta_net_fused_op
(
    const at::Tensor& mixed_qkvz,   // [B, S, Nk*(2*Hk + 2*Ng*Hv)], float32
    const at::Tensor& mixed_ba,     // [B, S, Nk*(2*Ng)], float32
    const at::Tensor& dt_bias,      // [Nv], bfloat16
    const at::Tensor& a_log,        // [Nv], bfloat16
    at::Tensor& mixed_qkv,          // out [B, 2*Nk*Hk + Nv*Hv, S], bfloat16
    at::Tensor& z,                  // out [B, S, Nv, Hv], bfloat16
    at::Tensor& beta,               // out [B, S, Nv], bfloat16
    at::Tensor& g,                  // out [B, S, Nv], float32
    size_t num_k_heads,
    size_t num_v_heads,
    size_t k_head_dim,
    size_t v_head_dim
)
{
    const at::cuda::OptionalCUDAGuard device_guard(mixed_qkvz.device());
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    const auto B = mixed_qkvz.size(0);
    const auto S = mixed_qkvz.size(1);
    const auto Nk = num_k_heads;
    const auto Hk = k_head_dim;
    const auto Hv = v_head_dim;
    const auto Nv = num_v_heads;

    TORCH_CHECK(Nk > 0 && Nv > 0 && Hk > 0 && Hv > 0, "invalid sizes");
    TORCH_CHECK(Nv % Nk == 0, "num_v_heads must be divisible by num_k_heads");
    const size_t Ng = Nv / Nk;

    const size_t Fseg = 2*Hk + 2*Ng*Hv;
    TORCH_CHECK(mixed_qkvz.size(2) == Nk * Fseg, "mixed_qkvz last dim should be Nk*(2*Hk + 2*Ng*Hv)");
    TORCH_CHECK(mixed_ba.size(2) == Nk * (2*Ng), "mixed_ba last dim should be Nk*(2*Ng)");
    TORCH_CHECK(mixed_qkv.size(1) == 2*Nk*Hk + Nv*Hv, "mixed_qkv must be [B, 2*Nk*Hk + Nv*Hv, S]");
    TORCH_CHECK(mixed_qkv.size(2) == S, "mixed_qkv must be [B, 2*Nk*Hk + Nv*Hv, S]");
    TORCH_CHECK(z.size(2) == Nv && z.size(3) == Hv, "z must be [B, S, Nv, Hv]");
    TORCH_CHECK(beta.size(2) == Nv, "beta must be [B, S, Nv]");
    TORCH_CHECK(g.size(2) == Nv, "g must be [B, S, Nv]");

    TORCH_CHECK_DTYPE(mixed_qkvz, kFloat);
    TORCH_CHECK_DTYPE(mixed_ba, kFloat);
    TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
    TORCH_CHECK_DTYPE(a_log, kBFloat16);
    TORCH_CHECK_DTYPE(mixed_qkv, kBFloat16);
    TORCH_CHECK_DTYPE(z, kBFloat16);
    TORCH_CHECK_DTYPE(beta, kBFloat16);
    TORCH_CHECK_DTYPE(g, kFloat);

    // One block per (b, s, kh). Clamp to int range; the kernel's grid-stride
    // loop covers any remainder.
    const size_t nlin = (size_t) B * (size_t) S * Nk;
    const int blocks = nlin > (size_t) 0x7FFFFFFF ? 0x7FFFFFFF : (int) nlin;

    // Threads must cover every per-thread segment used by the kernel:
    // Hk (q/k), Hv (v/z) and Ng (b/a).
    const int threads = (int) MAX(MAX(Hk, Hv), Ng);

    gated_delta_net_fused_op_kernel<<<blocks, threads, 0, stream>>>
    (
        (const float*) mixed_qkvz.data_ptr(),
        (const float*) mixed_ba.data_ptr(),
        (const __nv_bfloat16*) dt_bias.data_ptr(),
        (const __nv_bfloat16*) a_log.data_ptr(),
        (__nv_bfloat16*) mixed_qkv.data_ptr(),
        (__nv_bfloat16*) z.data_ptr(),
        (__nv_bfloat16*) beta.data_ptr(),
        (float*) g.data_ptr(),
        B, S, Nk, Ng, Hk, Hv
    );

    cuda_check(cudaPeekAtLastError());
}
|
||||
@@ -1,7 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import override
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -203,11 +201,14 @@ class GatedDeltaNet(Module):
|
||||
self.num_v_groups = num_v_heads // num_k_heads
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.k_dim = self.k_head_dim * self.num_k_heads
|
||||
self.v_dim = self.v_head_dim * self.num_v_heads
|
||||
|
||||
self.out_dtype = out_dtype
|
||||
|
||||
fdim_qkvz = 2 * self.num_k_heads * self.k_head_dim + 2 * self.v_head_dim * self.num_v_heads
|
||||
fdim_qkvz = 2 * self.num_k_heads * self.k_head_dim + 2 * self.num_v_heads * self.v_head_dim
|
||||
fdim_ba = 2 * self.num_v_heads
|
||||
self.fdim_qkv = 2 * self.num_k_heads * self.k_head_dim + self.num_v_heads * self.v_head_dim
|
||||
|
||||
if key_fused_qkvz:
|
||||
self.qkvz_proj = Linear(config, f"{key}.{key_fused_qkvz}", hidden_size, fdim_qkvz, qmap = qmap + ".input", out_dtype = torch.float)
|
||||
@@ -243,6 +244,7 @@ class GatedDeltaNet(Module):
|
||||
"recurrent_cache": True
|
||||
})
|
||||
|
||||
self.prealloc_split = None
|
||||
# self.cache_layers = []
|
||||
# self.tp_cache_lookup = {}
|
||||
# self.multi_kv = None
|
||||
@@ -270,6 +272,13 @@ class GatedDeltaNet(Module):
|
||||
self.norm.load(device, **kwargs)
|
||||
self.load_local(device, **kwargs)
|
||||
|
||||
# Preallocate (mixed_qkv, z, beta, g) tensors for bsz 1, seqlen 1
|
||||
self.prealloc_split = (
|
||||
torch.zeros((1, self.fdim_qkv, 1), dtype = torch.bfloat16, device = device),
|
||||
torch.zeros((1, 1, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = device),
|
||||
torch.zeros((1, 1, self.num_v_heads), dtype = torch.bfloat16, device = device),
|
||||
torch.zeros((1, 1, self.num_v_heads), dtype = torch.float, device = device)
|
||||
)
|
||||
|
||||
@override
|
||||
def unload(self):
|
||||
@@ -283,30 +292,45 @@ class GatedDeltaNet(Module):
|
||||
|
||||
|
||||
def split_fused_inputs(self, mixed_qkvz, mixed_ba):
    """
    Reference (non-fused) split of the projected qkvz/ba activations.

    :param mixed_qkvz:
        [bsz, seqlen, num_k_heads * (2 * k_head_dim + 2 * num_v_groups * v_head_dim)],
        contiguous
    :param mixed_ba:
        [bsz, seqlen, num_k_heads * 2 * num_v_groups], contiguous

    :return:
        (mixed_qkv [bsz, 2 * k_dim + v_dim, seqlen],
         z [bsz, seqlen, num_v_heads, v_head_dim],
         b [bsz, seqlen, num_v_heads],
         a [bsz, seqlen, num_v_heads])
    """
    # mixed_qkvz and mixed_ba have same (bsz, seqlen); both are contiguous
    bsz, seqlen, _ = mixed_qkvz.shape

    # Group features by k-head so q/k/v/z and b/a can be split per head
    mixed_qkvz = mixed_qkvz.view(
        bsz,
        seqlen,
        self.num_k_heads,
        2 * self.k_head_dim + 2 * self.v_head_dim * self.num_v_heads // self.num_k_heads,
    )
    mixed_ba = mixed_ba.view(
        bsz,
        seqlen,
        self.num_k_heads,
        2 * self.num_v_heads // self.num_k_heads
    )

    split_arg_list_qkvz = [
        self.k_head_dim,
        self.k_head_dim,
        (self.num_v_groups * self.v_head_dim),
        (self.num_v_groups * self.v_head_dim),
    ]
    split_arg_list_ba = [
        self.num_v_heads // self.num_k_heads,
        self.num_v_heads // self.num_k_heads
    ]
    q, k, v, z = torch.split(mixed_qkvz, split_arg_list_qkvz, dim = 3)
    b, a = torch.split(mixed_ba, split_arg_list_ba, dim = 3)

    # Flatten heads; z keeps a per-v-head dim, b/a collapse to [.., num_v_heads]
    q = q.reshape(bsz, seqlen, -1)
    k = k.reshape(bsz, seqlen, -1)
    v = v.reshape(bsz, seqlen, -1)
    z = z.reshape(bsz, seqlen, -1, self.v_head_dim)
    b = b.reshape(bsz, seqlen, self.num_v_heads)
    a = a.reshape(bsz, seqlen, self.num_v_heads)

    # Conv input is channels-first: [bsz, features, seqlen]
    mixed_qkv = torch.cat((q, k, v), dim = -1)
    mixed_qkv = mixed_qkv.transpose(1, 2)
    return mixed_qkv, z, b, a
||||
|
||||
|
||||
@override
|
||||
@@ -336,16 +360,29 @@ class GatedDeltaNet(Module):
|
||||
save_state = False
|
||||
|
||||
# Projections
|
||||
qkvz = self.qkvz_proj.forward(x, params).to(torch.bfloat16)
|
||||
ba = self.ba_proj.forward(x, params).to(torch.bfloat16)
|
||||
q, k, v, z, b, a = self.split_fused_inputs(qkvz, ba)
|
||||
q = q.reshape(bsz, seqlen, -1)
|
||||
k = k.reshape(bsz, seqlen, -1)
|
||||
v = v.reshape(bsz, seqlen, -1)
|
||||
qkvz = self.qkvz_proj.forward(x, params)
|
||||
ba = self.ba_proj.forward(x, params)
|
||||
|
||||
if bsz == 1 and seqlen == 1:
|
||||
mixed_qkv, z, beta, g = self.prealloc_split
|
||||
else:
|
||||
mixed_qkv = torch.zeros((bsz, self.fdim_qkv, seqlen), dtype = torch.bfloat16, device = self.device)
|
||||
z = torch.zeros((bsz, seqlen, self.num_v_heads, self.v_head_dim), dtype = torch.bfloat16, device = self.device)
|
||||
beta = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.bfloat16, device = self.device)
|
||||
g = torch.zeros((bsz, seqlen, self.num_v_heads), dtype = torch.float, device = self.device)
|
||||
|
||||
ext.gated_delta_net_fused_op(
|
||||
qkvz, ba,
|
||||
self.dt_bias,
|
||||
self.a_log,
|
||||
mixed_qkv, z, beta, g,
|
||||
self.num_k_heads,
|
||||
self.num_v_heads,
|
||||
self.k_head_dim,
|
||||
self.v_head_dim
|
||||
)
|
||||
|
||||
# Convolution
|
||||
mixed_qkv = torch.cat((q, k, v), dim = -1).transpose(1, 2)
|
||||
|
||||
if conv_state is None:
|
||||
if save_state:
|
||||
conv_state = F.pad(mixed_qkv, (self.conv_kernel_size - mixed_qkv.shape[-1], 0))
|
||||
@@ -374,18 +411,11 @@ class GatedDeltaNet(Module):
|
||||
mixed_qkv = mixed_qkv.transpose(1, 2)
|
||||
|
||||
# Gate
|
||||
k_dim = self.k_head_dim * self.num_k_heads
|
||||
v_dim = self.v_head_dim * self.num_v_heads
|
||||
q, k, v = torch.split(mixed_qkv, [k_dim, k_dim, v_dim], dim = -1)
|
||||
q, k, v = torch.split(mixed_qkv, [self.k_dim, self.k_dim, self.v_dim], dim = -1)
|
||||
q = q.view(bsz, seqlen, -1, self.k_head_dim)
|
||||
k = k.view(bsz, seqlen, -1, self.k_head_dim)
|
||||
v = v.view(bsz, seqlen, -1, self.v_head_dim)
|
||||
|
||||
if self.a_log_f_exp is None:
|
||||
self.a_log_f_exp = -self.a_log.float().exp()
|
||||
g = self.a_log_f_exp * F.softplus(a.float() + self.dt_bias)
|
||||
beta = b.sigmoid()
|
||||
|
||||
# Grouped attn
|
||||
if self.num_v_heads // self.num_k_heads > 1:
|
||||
q = q.repeat_interleave(self.num_v_groups, dim = 2)
|
||||
@@ -398,7 +428,7 @@ class GatedDeltaNet(Module):
|
||||
g = g,
|
||||
beta = beta,
|
||||
initial_state = recurrent_state,
|
||||
output_final_state = True, # cache_params is not None,
|
||||
output_final_state = save_state,
|
||||
use_qk_l2norm_in_kernel = True,
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user