mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-02-20 21:24:08 +00:00
cuda: add qwen3next delta-net kernel dispatch override
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
#include "common.cuh"
|
||||
#include "delta-net.cuh"
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
|
||||
// Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128)
|
||||
// State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major)
|
||||
@@ -1369,6 +1370,45 @@ __global__ void delta_net_multiblock_f32(
|
||||
}
|
||||
}
|
||||
|
||||
enum delta_net_opt_mode : int {
|
||||
DELTA_NET_OPT_DEFAULT = 0, // keep current dispatch
|
||||
DELTA_NET_OPT_FP16 = 1, // pre-Blackwell: fp16 recurrent kernel (head_dim=128)
|
||||
DELTA_NET_OPT_MULTIBLOCK = 2, // pre-Blackwell: multiblock kernel (head_dim=128)
|
||||
DELTA_NET_OPT_BW_OPT = 3, // Blackwell: padded/bank-conflict-reduced kernel
|
||||
DELTA_NET_OPT_AUTO = 4, // arch-aware: multiblock (pre-BW), bw-opt (BW)
|
||||
};
|
||||
|
||||
static int delta_net_get_opt_mode() {
|
||||
static const int mode = []() -> int {
|
||||
const char * env = std::getenv("GGML_CUDA_DELTA_NET_OPT");
|
||||
if (env == nullptr || env[0] == '\0') {
|
||||
return DELTA_NET_OPT_DEFAULT;
|
||||
}
|
||||
|
||||
if (!strcmp(env, "auto") || !strcmp(env, "AUTO")) {
|
||||
return DELTA_NET_OPT_AUTO;
|
||||
}
|
||||
if (!strcmp(env, "fp16")) {
|
||||
return DELTA_NET_OPT_FP16;
|
||||
}
|
||||
if (!strcmp(env, "multiblock")) {
|
||||
return DELTA_NET_OPT_MULTIBLOCK;
|
||||
}
|
||||
if (!strcmp(env, "blackwell-opt")) {
|
||||
return DELTA_NET_OPT_BW_OPT;
|
||||
}
|
||||
|
||||
const int parsed = atoi(env);
|
||||
if (parsed >= DELTA_NET_OPT_DEFAULT && parsed <= DELTA_NET_OPT_AUTO) {
|
||||
return parsed;
|
||||
}
|
||||
|
||||
return DELTA_NET_OPT_DEFAULT;
|
||||
}();
|
||||
|
||||
return mode;
|
||||
}
|
||||
|
||||
// Dispatch function
|
||||
// device_id and cc (compute capability) are passed from caller to avoid CUDA runtime API calls
|
||||
static void delta_net_f32_cuda(
|
||||
@@ -1399,6 +1439,7 @@ static void delta_net_f32_cuda(
|
||||
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
|
||||
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
|
||||
const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
|
||||
const int opt_mode = delta_net_get_opt_mode();
|
||||
|
||||
// Use templated kernel for common head dimensions, generic for others
|
||||
if (head_dim == 64) {
|
||||
@@ -1410,23 +1451,59 @@ static void delta_net_f32_cuda(
|
||||
// cc is in format MAJOR*100 + MINOR*10 (e.g., 890 for 8.9, 1200 for 12.0)
|
||||
const int sm_major = cc / 100;
|
||||
|
||||
if (sm_major >= 12) {
|
||||
// Blackwell path: single block per head with FULL state in shared memory
|
||||
constexpr size_t blackwell_state_bytes = 128 * 128 * sizeof(float); // 64 KB
|
||||
constexpr size_t blackwell_vector_bytes = 9 * 128 * sizeof(float); // 4.5 KB
|
||||
constexpr size_t blackwell_warp_scratch_bytes = 16 * sizeof(float); // 64 B
|
||||
constexpr size_t blackwell_smem_size =
|
||||
blackwell_state_bytes + blackwell_vector_bytes + blackwell_warp_scratch_bytes;
|
||||
static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
|
||||
|
||||
constexpr size_t blackwell_opt_state_bytes = 128 * 132 * sizeof(float); // padded 128x132
|
||||
constexpr size_t blackwell_opt_vector_bytes = 9 * 128 * sizeof(float);
|
||||
constexpr size_t blackwell_opt_warp_scratch_bytes = 16 * sizeof(float);
|
||||
constexpr size_t blackwell_opt_smem_size =
|
||||
blackwell_opt_state_bytes + blackwell_opt_vector_bytes + blackwell_opt_warp_scratch_bytes;
|
||||
static_assert(blackwell_opt_smem_size == 72256, "Optimized shared memory size mismatch");
|
||||
|
||||
constexpr int multiblock_cols = 16;
|
||||
constexpr int multiblock_groups = 128 / multiblock_cols;
|
||||
constexpr size_t multiblock_smem_floats =
|
||||
128 * multiblock_cols + 3 * 128 + 5 * multiblock_cols + 2;
|
||||
constexpr size_t multiblock_smem_size = multiblock_smem_floats * sizeof(float);
|
||||
static_assert(multiblock_smem_size == 10056, "Multiblock shared memory size mismatch");
|
||||
|
||||
constexpr size_t fp16_state_bytes = 128 * 128 * sizeof(half);
|
||||
constexpr size_t fp16_half_vec_bytes = 3 * 128 * sizeof(half);
|
||||
constexpr size_t fp16_float_vec_bytes = 6 * 128 * sizeof(float);
|
||||
constexpr size_t fp16_scalar_bytes = 2 * sizeof(float);
|
||||
constexpr size_t fp16_smem_size =
|
||||
fp16_state_bytes + fp16_half_vec_bytes + fp16_float_vec_bytes + fp16_scalar_bytes;
|
||||
static_assert(fp16_smem_size == 36616, "FP16 shared memory size mismatch");
|
||||
|
||||
// Keep "auto" conservative on Blackwell (baseline kernel remains default there).
|
||||
// Explicit modes can still force a different kernel for experiments.
|
||||
const bool use_bw_opt =
|
||||
sm_major >= 12 && opt_mode == DELTA_NET_OPT_BW_OPT;
|
||||
const bool use_multiblock =
|
||||
opt_mode == DELTA_NET_OPT_MULTIBLOCK ||
|
||||
(sm_major < 12 && opt_mode == DELTA_NET_OPT_AUTO);
|
||||
const bool use_fp16 = opt_mode == DELTA_NET_OPT_FP16;
|
||||
|
||||
if (use_bw_opt) {
|
||||
const int blackwell_num_blocks = n_seqs * n_heads;
|
||||
const int blackwell_threads = 256;
|
||||
|
||||
// Shared memory calculation with explicit breakdown:
|
||||
// - State matrix: HEAD_DIM × HEAD_DIM × sizeof(float) = 128×128×4 = 65536 bytes (64KB)
|
||||
// - Vectors (Q,K,V,KBeta,VBeta,KCumdecay,VPrime,VNew,Out): 9 × HEAD_DIM × sizeof(float) = 4608 bytes
|
||||
// - Warp scratch: 16 × sizeof(float) = 64 bytes
|
||||
// Total: 65536 + 4608 + 64 = 70208 bytes (~68.6KB)
|
||||
// __shared__ scalars (decay, beta, etc.) are static, not dynamic
|
||||
constexpr size_t state_bytes = 128 * 128 * sizeof(float); // 64KB
|
||||
constexpr size_t vector_bytes = 9 * 128 * sizeof(float); // 4.5KB
|
||||
constexpr size_t warp_scratch_bytes = 16 * sizeof(float); // 64B
|
||||
constexpr size_t blackwell_smem_size = state_bytes + vector_bytes + warp_scratch_bytes;
|
||||
CUDA_CHECK(cudaFuncSetAttribute(
|
||||
delta_net_blackwell_optimized_f32<128>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
blackwell_opt_smem_size));
|
||||
|
||||
static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
|
||||
delta_net_blackwell_optimized_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_opt_smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
} else if (sm_major >= 12) {
|
||||
// Blackwell path: single block per head with FULL state in shared memory
|
||||
const int blackwell_num_blocks = n_seqs * n_heads;
|
||||
const int blackwell_threads = 256;
|
||||
|
||||
// A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1)
|
||||
static const bool ab_mode = []() {
|
||||
@@ -1530,13 +1607,23 @@ static void delta_net_f32_cuda(
|
||||
delta_net_blackwell_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
}
|
||||
} else
|
||||
#endif // !defined(GGML_USE_HIP)
|
||||
{
|
||||
// Pre-Blackwell path: Use recurrent kernel
|
||||
} else if (use_multiblock) {
|
||||
const int multiblock_num_blocks = n_seqs * n_heads * multiblock_groups;
|
||||
delta_net_multiblock_f32<128, multiblock_cols><<<multiblock_num_blocks, threads_per_block, multiblock_smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
} else if (use_fp16) {
|
||||
delta_net_fp16_optimized<128><<<num_blocks, threads_per_block, fp16_smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
} else {
|
||||
// Baseline pre-Blackwell path
|
||||
delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
}
|
||||
#else
|
||||
// HIP path: keep baseline recurrent implementation
|
||||
delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
#endif // !defined(GGML_USE_HIP)
|
||||
} else {
|
||||
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
|
||||
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
|
||||
|
||||
Reference in New Issue
Block a user