diff --git a/ggml/src/ggml-cuda/delta-net.cu b/ggml/src/ggml-cuda/delta-net.cu
index f8e5daee..8291a9f8 100644
--- a/ggml/src/ggml-cuda/delta-net.cu
+++ b/ggml/src/ggml-cuda/delta-net.cu
@@ -1,6 +1,7 @@
 #include "common.cuh"
 #include "delta-net.cuh"
 #include <cstdlib>
+#include <cstring>
 
 // Delta Net Linear Attention Kernel for Qwen3-Next (HEAD_DIM=128)
 // State layout: [S_v, S_v*H_v, 1, n_seqs] (column-major)
@@ -1369,6 +1370,47 @@ __global__ void delta_net_multiblock_f32(
     }
 }
 
+enum delta_net_opt_mode : int {
+    DELTA_NET_OPT_DEFAULT    = 0, // keep current dispatch
+    DELTA_NET_OPT_FP16       = 1, // pre-Blackwell: fp16 recurrent kernel (head_dim=128)
+    DELTA_NET_OPT_MULTIBLOCK = 2, // pre-Blackwell: multiblock kernel (head_dim=128)
+    DELTA_NET_OPT_BW_OPT     = 3, // Blackwell: padded/bank-conflict-reduced kernel
+    DELTA_NET_OPT_AUTO       = 4, // arch-aware: multiblock (pre-BW), baseline (BW)
+};
+
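+// Reads GGML_CUDA_DELTA_NET_OPT once (result cached); accepts the mode names "auto",
+// "fp16", "multiblock", "blackwell-opt", or a bare integer 0..4 matching the enum above.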
+static int delta_net_get_opt_mode() {
+    static const int mode = []() -> int {
+        const char * env = std::getenv("GGML_CUDA_DELTA_NET_OPT");
+        if (env == nullptr || env[0] == '\0') {
+            return DELTA_NET_OPT_DEFAULT;
+        }
+
+        if (!strcmp(env, "auto") || !strcmp(env, "AUTO")) {
+            return DELTA_NET_OPT_AUTO;
+        }
+        if (!strcmp(env, "fp16")) {
+            return DELTA_NET_OPT_FP16;
+        }
+        if (!strcmp(env, "multiblock")) {
+            return DELTA_NET_OPT_MULTIBLOCK;
+        }
+        if (!strcmp(env, "blackwell-opt")) {
+            return DELTA_NET_OPT_BW_OPT;
+        }
+
+        const int parsed = atoi(env);
+        if (parsed >= DELTA_NET_OPT_DEFAULT && parsed <= DELTA_NET_OPT_AUTO) {
+            return parsed;
+        }
+
+        return DELTA_NET_OPT_DEFAULT;
+    }();
+
+    return mode;
+}
+
 
 // Dispatch function
 // device_id and cc (compute capability) are passed from caller to avoid CUDA runtime API calls
 static void delta_net_f32_cuda(
@@ -1399,6 +1441,7 @@ static void delta_net_f32_cuda(
     // Shared memory: 9 * head_dim (for Q, K, V, KBeta, VBeta, Out, KCumdecay, VPrime, VNew)
     // Plus 6 floats for Norm[2], g_val, beta_val, decay, attn_score
     const size_t smem_size = (9 * head_dim + 6) * sizeof(float);
+    const int opt_mode = delta_net_get_opt_mode();
 
     // Use templated kernel for common head dimensions, generic for others
     if (head_dim == 64) {
@@ -1410,23 +1453,61 @@ static void delta_net_f32_cuda(
         // cc is in format MAJOR*100 + MINOR*10 (e.g., 890 for 8.9, 1200 for 12.0)
         const int sm_major = cc / 100;
 
-        if (sm_major >= 12) {
-            // Blackwell path: single block per head with FULL state in shared memory
+        constexpr size_t blackwell_state_bytes        = 128 * 128 * sizeof(float); // 64 KB
+        constexpr size_t blackwell_vector_bytes       = 9 * 128 * sizeof(float);   // 4.5 KB
+        constexpr size_t blackwell_warp_scratch_bytes = 16 * sizeof(float);        // 64 B
+        constexpr size_t blackwell_smem_size =
+            blackwell_state_bytes + blackwell_vector_bytes + blackwell_warp_scratch_bytes;
+        static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
+
+        constexpr size_t blackwell_opt_state_bytes        = 128 * 132 * sizeof(float); // padded 128x132
+        constexpr size_t blackwell_opt_vector_bytes       = 9 * 128 * sizeof(float);
+        constexpr size_t blackwell_opt_warp_scratch_bytes = 16 * sizeof(float);
+        constexpr size_t blackwell_opt_smem_size =
+            blackwell_opt_state_bytes + blackwell_opt_vector_bytes + blackwell_opt_warp_scratch_bytes;
+        static_assert(blackwell_opt_smem_size == 72256, "Optimized shared memory size mismatch");
+
+        constexpr int multiblock_cols   = 16;
+        constexpr int multiblock_groups = 128 / multiblock_cols;
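+        // Float count: 128 x 16 state tile (2048) + 3 x 128 vectors (384)
+        // + 5 x 16 per-column values (80) + 2 scalars = 2514 floats -> 10056 bytes.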
+        constexpr size_t multiblock_smem_floats =
+            128 * multiblock_cols + 3 * 128 + 5 * multiblock_cols + 2;
+        constexpr size_t multiblock_smem_size = multiblock_smem_floats * sizeof(float);
+        static_assert(multiblock_smem_size == 10056, "Multiblock shared memory size mismatch");
+
+        constexpr size_t fp16_state_bytes     = 128 * 128 * sizeof(half);
+        constexpr size_t fp16_half_vec_bytes  = 3 * 128 * sizeof(half);
+        constexpr size_t fp16_float_vec_bytes = 6 * 128 * sizeof(float);
+        constexpr size_t fp16_scalar_bytes    = 2 * sizeof(float);
+        constexpr size_t fp16_smem_size =
+            fp16_state_bytes + fp16_half_vec_bytes + fp16_float_vec_bytes + fp16_scalar_bytes;
+        static_assert(fp16_smem_size == 36616, "FP16 shared memory size mismatch");
+
+        // Keep "auto" conservative on Blackwell: baseline stays the default there and only
+        // "blackwell-opt" overrides it; fp16/multiblock select kernels on pre-Blackwell GPUs.
+        const bool use_bw_opt =
+            sm_major >= 12 && opt_mode == DELTA_NET_OPT_BW_OPT;
+        const bool use_multiblock =
+            opt_mode == DELTA_NET_OPT_MULTIBLOCK ||
+            (sm_major < 12 && opt_mode == DELTA_NET_OPT_AUTO);
+        const bool use_fp16 = opt_mode == DELTA_NET_OPT_FP16;
+
+        if (use_bw_opt) {
             const int blackwell_num_blocks = n_seqs * n_heads;
             const int blackwell_threads = 256;
 
-            // Shared memory calculation with explicit breakdown:
-            // - State matrix: HEAD_DIM × HEAD_DIM × sizeof(float) = 128×128×4 = 65536 bytes (64KB)
-            // - Vectors (Q,K,V,KBeta,VBeta,KCumdecay,VPrime,VNew,Out): 9 × HEAD_DIM × sizeof(float) = 4608 bytes
-            // - Warp scratch: 16 × sizeof(float) = 64 bytes
-            // Total: 65536 + 4608 + 64 = 70208 bytes (~68.6KB)
-            // __shared__ scalars (decay, beta, etc.) are static, not dynamic
-            constexpr size_t state_bytes = 128 * 128 * sizeof(float); // 64KB
-            constexpr size_t vector_bytes = 9 * 128 * sizeof(float); // 4.5KB
-            constexpr size_t warp_scratch_bytes = 16 * sizeof(float); // 64B
-            constexpr size_t blackwell_smem_size = state_bytes + vector_bytes + warp_scratch_bytes;
+            CUDA_CHECK(cudaFuncSetAttribute(
+                delta_net_blackwell_optimized_f32<128>,
+                cudaFuncAttributeMaxDynamicSharedMemorySize,
+                blackwell_opt_smem_size));
 
-            static_assert(blackwell_smem_size == 70208, "Shared memory size mismatch");
+            delta_net_blackwell_optimized_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_opt_smem_size, stream>>>(
+                q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
+        } else if (sm_major >= 12) {
+            // Blackwell path: single block per head with FULL state in shared memory
+            const int blackwell_num_blocks = n_seqs * n_heads;
+            const int blackwell_threads = 256;
 
             // A/B comparison mode (set GGML_CUDA_DELTA_NET_AB=1)
             static const bool ab_mode = []() {
@@ -1530,13 +1611,23 @@ static void delta_net_f32_cuda(
                 delta_net_blackwell_f32<128><<<blackwell_num_blocks, blackwell_threads, blackwell_smem_size, stream>>>(
                     q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
             }
-        } else
-#endif // !defined(GGML_USE_HIP)
-        {
-            // Pre-Blackwell path: Use recurrent kernel
+        } else if (use_multiblock) {
+            const int multiblock_num_blocks = n_seqs * n_heads * multiblock_groups;
+            delta_net_multiblock_f32<128, multiblock_cols><<<multiblock_num_blocks, 128, multiblock_smem_size, stream>>>(
+                q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
+        } else if (use_fp16) {
+            delta_net_fp16_optimized<128><<<n_seqs * n_heads, 128, fp16_smem_size, stream>>>(
+                q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
+        } else {
+            // Baseline pre-Blackwell path
             delta_net_recurrent_f32<128><<<n_seqs * n_heads, 128, smem_size, stream>>>(
                 q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
         }
+#else
+        // HIP path: keep baseline recurrent implementation
+        delta_net_recurrent_f32<128><<<n_seqs * n_heads, 128, smem_size, stream>>>(
+            q, k, v, g, beta, state_in, dst, n_tokens, n_heads, n_seqs, output_offset, eps);
+#endif // !defined(GGML_USE_HIP)
     } else {
         delta_net_recurrent_generic_f32<<<n_seqs * n_heads, head_dim, smem_size, stream>>>(
             q, k, v, g, beta, state_in, dst, head_dim, n_tokens, n_heads, n_seqs, output_offset, eps);
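
Mode summary (a sketch derived from the dispatch above, for head_dim == 128 on CUDA; HIP always takes the baseline recurrent kernel):

    GGML_CUDA_DELTA_NET_OPT unset / "0" -> current dispatch (Blackwell kernel on sm 12.x, recurrent otherwise)
    "fp16" / "1"                        -> fp16 recurrent kernel, takes effect pre-Blackwell only
    "multiblock" / "2"                  -> multiblock kernel, takes effect pre-Blackwell only
    "blackwell-opt" / "3"               -> padded 128x132 kernel, takes effect on sm 12.x only
    "auto" / "4"                        -> multiblock pre-Blackwell, baseline Blackwell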