cuda: add qwen3next delta-net kernel dispatch override

Author: yurko
Date:   2026-02-08 14:38:30 -08:00
Parent: b5c9554a88
Commit: eef360a85f


@@ -1,6 +1,7 @@
#include "common.cuh"
#include "delta-net.cuh"
#include <cstdlib>
#include <cstring>
// Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128)
// State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major)
@@ -1369,6 +1370,45 @@ __global__ void delta_net_multiblock_f32(
}
}
enum delta_net_opt_mode : int {
DELTA_NET_OPT_DEFAULT = 0, // keep current dispatch
DELTA_NET_OPT_FP16 = 1, // pre-Blackwell: fp16 recurrent kernel (head_dim=128)
DELTA_NET_OPT_MULTIBLOCK = 2, // pre-Blackwell: multiblock kernel (head_dim=128)
DELTA_NET_OPT_BW_OPT = 3, // Blackwell: padded/bank-conflict-reduced kernel
DELTA_NET_OPT_AUTO = 4, // arch-aware: multiblock (pre-BW), bw-opt (BW)
};
static int delta_net_get_opt_mode() {
static const int mode = []() -> int {
const char * env = std::getenv("GGML_CUDA_DELTA_NET_OPT");
if (env == nullptr || env[0] == '\0') {
return DELTA_NET_OPT_DEFAULT;
}
if (!std::strcmp(env, "auto") || !std::strcmp(env, "AUTO")) {
return DELTA_NET_OPT_AUTO;
}
if (!std::strcmp(env, "fp16")) {
return DELTA_NET_OPT_FP16;
}
if (!std::strcmp(env, "multiblock")) {
return DELTA_NET_OPT_MULTIBLOCK;
}
if (!std::strcmp(env, "blackwell-opt")) {
return DELTA_NET_OPT_BW_OPT;
}
const int parsed = std::atoi(env);
if (parsed >= DELTA_NET_OPT_DEFAULT && parsed <= DELTA_NET_OPT_AUTO) {
return parsed;
}
return DELTA_NET_OPT_DEFAULT;
}();
return mode;
}
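// Usage note (illustrative, not part of the dispatch itself): the override is read once
// per process (static initializer), so it must be set in the environment before the first
// delta-net launch, e.g.
//   GGML_CUDA_DELTA_NET_OPT=auto           -> arch-aware pick (multiblock pre-Blackwell, baseline on Blackwell)
//   GGML_CUDA_DELTA_NET_OPT=blackwell-opt  -> force the padded Blackwell kernel
//   GGML_CUDA_DELTA_NET_OPT=2              -> numeric values 0..4 map directly to the enum above
// Unrecognized strings (and out-of-range numbers) fall back to DELTA_NET_OPT_DEFAULT.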
// Dispatch function
// device_id and cc (compute capability) are passed from caller to avoid CUDA runtime API calls
static void delta_net_f32_cuda(
@@ -1399,6 +1439,7 @@ static void delta_net_f32_cuda(
// Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
// Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
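// Illustrative arithmetic (not in the original comment): for head_dim = 128 this is
// (9 * 128 + 6) * sizeof(float) = 1158 * 4 = 4632 bytes of dynamic shared memory per block.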
const int opt_mode = delta_net_get_opt_mode();
// Use templated kernel for common head dimensions, generic for others
if (head_dim == 64) {
@@ -1410,23 +1451,59 @@ static void delta_net_f32_cuda(
// cc is in format MAJOR*100 + MINOR*10 (e.g., 890 for 8.9, 1200 for 12.0)
const int sm_major = cc / 100;
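// Illustrative examples of the integer division: cc = 890 (Ada, 8.9) -> sm_major = 8,
// cc = 1200 (Blackwell, 12.0) -> sm_major = 12, which gates the Blackwell-only branches below.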
constexpr size_t blackwell_state_bytes = 128 * 128 * sizeof(float); // 64 KB
constexpr size_t blackwell_vector_bytes = 9 * 128 * sizeof(float); // 4.5 KB
constexpr size_t blackwell_warp_scratch_bytes = 16 * sizeof(float); // 64 B
constexpr size_t blackwell_smem_size =
blackwell_state_bytes + blackwell_vector_bytes + blackwell_warp_scratch_bytes;
static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
constexpr size_t blackwell_opt_state_bytes = 128 * 132 * sizeof(float); // padded 128x132
constexpr size_t blackwell_opt_vector_bytes = 9 * 128 * sizeof(float);
constexpr size_t blackwell_opt_warp_scratch_bytes = 16 * sizeof(float);
constexpr size_t blackwell_opt_smem_size =
blackwell_opt_state_bytes + blackwell_opt_vector_bytes + blackwell_opt_warp_scratch_bytes;
static_assert(blackwell_opt_smem_size == 72256, "Optimized shared memory size mismatch");
constexpr int multiblock_cols = 16;
constexpr int multiblock_groups = 128 / multiblock_cols;
constexpr size_t multiblock_smem_floats =
128 * multiblock_cols + 3 * 128 + 5 * multiblock_cols + 2;
constexpr size_t multiblock_smem_size = multiblock_smem_floats * sizeof(float);
static_assert(multiblock_smem_size == 10056, "Multiblock shared memory size mismatch");
constexpr size_t fp16_state_bytes = 128 * 128 * sizeof(half);
constexpr size_t fp16_half_vec_bytes = 3 * 128 * sizeof(half);
constexpr size_t fp16_float_vec_bytes = 6 * 128 * sizeof(float);
constexpr size_t fp16_scalar_bytes = 2 * sizeof(float);
constexpr size_t fp16_smem_size =
fp16_state_bytes + fp16_half_vec_bytes + fp16_float_vec_bytes + fp16_scalar_bytes;
static_assert(fp16_smem_size == 36616, "FP16 shared memory size mismatch");
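// Recap of the dynamic shared-memory budgets asserted above (illustrative, head_dim = 128):
//   baseline Blackwell kernel : 65536 + 4608 + 64         = 70208 bytes
//   padded Blackwell kernel   : 67584 + 4608 + 64         = 72256 bytes
//   multiblock kernel         : (2048 + 384 + 80 + 2) * 4 = 10056 bytes
//   fp16 recurrent kernel     : 32768 + 768 + 3072 + 8    = 36616 bytes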
// Keep "auto" conservative on Blackwell (baseline kernel remains default there).
// Explicit modes can still force a different kernel for experiments.
const bool use_bw_opt =
sm_major >= 12 && opt_mode == DELTA_NET_OPT_BW_OPT;
const bool use_multiblock =
opt_mode == DELTA_NET_OPT_MULTIBLOCK ||
(sm_major < 12 && opt_mode == DELTA_NET_OPT_AUTO);
const bool use_fp16 = opt_mode == DELTA_NET_OPT_FP16;
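// Resulting dispatch (illustrative summary of the flags above):
//   Blackwell,     mode == blackwell-opt         -> delta_net_blackwell_optimized_f32 (padded state)
//   Blackwell,     any other mode (incl. auto)   -> delta_net_blackwell_f32 (baseline)
//   pre-Blackwell, mode == multiblock or auto    -> delta_net_multiblock_f32
//   pre-Blackwell, mode == fp16                  -> delta_net_fp16_optimized
//   pre-Blackwell, default (or blackwell-opt)    -> delta_net_recurrent_f32 (baseline)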
if (use_bw_opt) {
const int blackwell_num_blocks = n_seqs * n_heads;
const int blackwell_threads = 256;
// Optimized Blackwell kernel: full state in shared memory, padded from 128×128 to 128×132
// to reduce shared-memory bank conflicts.
// - State matrix: 128 × 132 × sizeof(float) = 67584 bytes
// - Vectors (Q, K, V, KBeta, VBeta, KCumdecay, VPrime, VNew, Out): 9 × 128 × 4 = 4608 bytes
// - Warp scratch: 16 × sizeof(float) = 64 bytes
// Total: 72256 bytes (blackwell_opt_smem_size, asserted above)
// __shared__ scalars (decay, beta, etc.) are static, not dynamic
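// Note (general CUDA behavior, not specific to this patch): dynamic shared memory above
// 48 KB per block requires an explicit opt-in via cudaFuncAttributeMaxDynamicSharedMemorySize,
// which is what the cudaFuncSetAttribute call below provides for the 72256-byte request.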
CUDA_CHECK(cudaFuncSetAttribute(
delta_net_blackwell_optimized_f32<128>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
blackwell_opt_smem_size));
delta_net_blackwell_optimized_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_opt_smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
} else if (sm_major >= 12) {
// Blackwell path: single block per head with FULL state in shared memory
const int blackwell_num_blocks = n_seqs * n_heads;
const int blackwell_threads = 256;
// A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1)
static const bool ab_mode = []() {
@@ -1530,13 +1607,23 @@ static void delta_net_f32_cuda(
delta_net_blackwell_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
}
} else if (use_multiblock) {
const int multiblock_num_blocks = n_seqs * n_heads * multiblock_groups;
delta_net_multiblock_f32<128, multiblock_cols><<<multiblock_num_blocks, threads_per_block, multiblock_smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
} else if (use_fp16) {
delta_net_fp16_optimized<128><<<num_blocks, threads_per_block, fp16_smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
} else {
// Baseline pre-Blackwell path
delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
}
#else
// HIP path: keep baseline recurrent implementation
delta_net_recurrent_f32<128><<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
#endif // !defined(GGML_USE_HIP)
} else {
delta_net_recurrent_generic_f32<<<num_blocks, threads_per_block, smem_size, stream>>>(
q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);