diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 741ef4062d..f08dd1ca90 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -1062,6 +1062,20 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): else: pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, "f", "f", qscale, mask, "f", "f", sink)) # fmt: skip + + # PER_TOKEN_HEAD: FP8 fine-grained Q/K per-token-per-head + V per-head. + # Currently only wired for fp8bf16 + qr_async + hdim=128 on gfx9 (MI300/MI308). + # The non-paged fmha_fwd qr_async pipeline gained PER_TOKEN_HEAD support in + # block_fmha_pipeline_qr_ks_vs_async.hpp; keep this scope tight until we + # also extend qr / qr_async_trload / V3 pipelines. + if dtype in cls._DT_FP8BF16 and hdim == 128: + for logits, mask, sink in itertools.product( + ["t", "f"], + get_mask_map(mask_impl).keys(), + ["f", "t"], + ): + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip + pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, "no", "f", "f", "per_token_head", mask, "f", "f", sink)) # fmt: skip return pipelines diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index fcb73c48b7..e46f7bfe32 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -305,6 +305,41 @@ struct FmhaFwdKernel const int32_t* seqstart_v_scale_ptr; }; + // PER_TOKEN_HEAD descale layout (non-paged fmha_fwd): + // q_descale: [B, total_q, nhead_q] fp32 (batch+token row stride, head col stride) + // k_descale: [B, total_k, nhead_k] fp32 + // v_descale: [nhead_k] fp32 (per-head only; no batch / token dim) + // For varlen (group) mode the batch dim is collapsed into total_q / total_k and the + // per-batch offset is derived from cu_seqlens * stride_*_descale instead of + // batch_stride_*_descale. + struct FmhaFwdCommonPerTokenHeadKargs : FmhaFwdCommonQScaleKargs + { + // Per-token (row) strides; V is per-head only, so stride_v_descale is unused. + ck_tile::index_t stride_q_descale; + ck_tile::index_t stride_k_descale; + + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + // Unused under PER_TOKEN_HEAD but the qr_async pipeline takes it positionally; + // keep it as a no-op field rather than baking a constant into the call site. + ck_tile::index_t block_scale_size_kv = 128; + }; + + struct FmhaFwdBatchPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; // callers should set this to 0 (V is per-head) + }; + + struct FmhaFwdGroupPerTokenHeadKargs : FmhaFwdCommonPerTokenHeadKargs + { + // group mode resolves per-batch offsets from cu_seqlens * stride_*_descale, + // so no dedicated batch_stride_*_descale is needed here. + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -383,11 +418,16 @@ struct FmhaFwdKernel std::conditional_t< QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, FmhaFwdCommonQScaleKargs, - std::conditional_t>>>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE, + FmhaFwdBatchBlockScaleKargs, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::MX, + FmhaFwdBatchMXKargs, + std::conditional_t>>>>, std::conditional_t>, std::conditional_t> { @@ -414,11 +454,16 @@ struct FmhaFwdKernel std::conditional_t< QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, FmhaFwdCommonQScaleKargs, - std::conditional_t>>>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE, + FmhaFwdGroupBlockScaleKargs, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::MX, + FmhaFwdGroupMXKargs, + std::conditional_t>>>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -612,6 +657,28 @@ struct FmhaFwdKernel kargs.batch_stride_k_descale = batch_stride_k_descale; kargs.batch_stride_v_descale = batch_stride_v_descale; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD (non-paged fmha_fwd) descale layout: + // q_descale: [total_q, nhead_q] fp32 (per-token, per-head) + // k_descale: [total_k, nhead_k] fp32 (per-token, per-head) + // v_descale: [nhead_k] fp32 (per-head only) + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + // per-token row strides (V is per-head scalar so stride_v_descale is unused). + kargs.stride_q_descale = stride_q_descale; + kargs.stride_k_descale = stride_k_descale; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1065,6 +1132,25 @@ struct FmhaFwdKernel kargs.seqstart_v_scale_ptr = reinterpret_cast(seqstart_v_scale_ptr); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD (non-paged fmha_fwd, group mode) descale layout: + // q_descale: [total_q, nhead_q] fp32 (per-token, per-head) + // k_descale: [total_k, nhead_k] fp32 (per-token, per-head) + // v_descale: [nhead_k] fp32 (per-head only) + // Per-batch offsets are derived from query_start/key_start in the kernel, + // so we don't need batch_stride_*_descale here (group mode). + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.stride_q_descale = stride_q_descale; + kargs.stride_k_descale = stride_k_descale; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -1574,6 +1660,16 @@ struct FmhaFwdKernel batch_offset_k_descale = key_start * kargs.stride_k_descale; batch_offset_v_descale = kargs.seqstart_v_scale_ptr[i_batch]; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD descales: Q/K per-token-per-head, V per-head only. + // q_descale: [total_q, nhead_q] -> per-batch offset = query_start * row_stride + // k_descale: [total_k, nhead_k] -> per-batch offset = key_start * row_stride + // v_descale: [nhead_k] -> no per-batch offset (shared across batches) + batch_offset_q_descale = query_start * kargs.stride_q_descale; + batch_offset_k_descale = key_start * kargs.stride_k_descale; + batch_offset_v_descale = 0; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD) @@ -1642,8 +1738,12 @@ struct FmhaFwdKernel static_cast(i_batch) * kargs.batch_stride_randval; } if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE || - QScaleEnum == BlockAttentionQuantScaleEnum::MX) + QScaleEnum == BlockAttentionQuantScaleEnum::MX || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { + // For PER_TOKEN_HEAD V is per-head only (no per-batch dimension); + // callers should set kargs.batch_stride_v_descale = 0 in that case + // so this expression evaluates to 0 here. batch_offset_q_descale = static_cast(i_batch) * kargs.batch_stride_q_descale; batch_offset_k_descale = @@ -2121,6 +2221,66 @@ struct FmhaFwdKernel make_null_tile_window(make_tuple()), sink_value); } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD: Q/K per-(token, head), V per-head (FP8 fine-grained). + // Resolve descale base ptrs to (batch, head); the pipeline indexes + // per-token offsets within head using stride_q/k_descale. + // + // Layout convention (non-paged fmha_fwd): + // q_descale: [total_q, nhead_q] fp32 + // k_descale: [total_k, nhead_k] fp32 + // v_descale: [nhead_k] fp32 + const float* q_descale_ptr = + reinterpret_cast(kargs.q_descale_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_q_descale + + batch_offset_q_descale; + const float* k_descale_ptr = + reinterpret_cast(kargs.k_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_k_descale + + batch_offset_k_descale; + const float* v_descale_ptr = + reinterpret_cast(kargs.v_descale_ptr) + + static_cast(i_nhead_ / kargs.nhead_ratio_qk) * + kargs.nhead_stride_v_descale + + batch_offset_v_descale; + + return FmhaPipeline{}( + q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func - PER_TOKEN_HEAD applies its own per-(row,col) scale + identity{}, // p_compute_element_func + identity{}, // o_acc_element_func - V descale is folded in via 'o_acc += o_acc0 * v_descale' + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout, + k_descale_ptr, + v_descale_ptr, + kargs.block_scale_size_kv, // unused for PER_TOKEN_HEAD + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + make_null_tile_window(make_tuple()), + sink_value, + // PER_TOKEN_HEAD-only: + q_descale_ptr, + kargs.stride_q_descale, + kargs.stride_k_descale); + } else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX) { using QScaleDataType = typename FmhaPipeline::QScaleDataType; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 7b97d01fa4..4f6fca9178 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -204,7 +204,16 @@ struct BlockFmhaPipelineQRKSVSAsync const QScaleDramBlockWindowTmp&, // M0*(K0/kQKScaleGranularity) tile const KScaleDramBlockWindowTmp&, // N0*(K0/kQKScaleGranularity) tile const VScaleDramBlockWindowTmp&, // N1*(K1/kVScaleGranularity) tile - const float sink_v) const + const float sink_v, + // PER_TOKEN_HEAD: per-(token, head) Q/K descales; V is per-head scalar. + // q_descale_ptr / k_descale_ptr / v_descale_ptr are already + // (batch, head)-resolved by the kernel; pipeline indexes + // q_descale_ptr[(q_origin + i) * stride_q_descale_token] + // k_descale_ptr[(k_origin + j) * stride_k_descale_token] + // v_descale = *v_descale_ptr (loaded once per loop iter) + const float* q_descale_ptr = nullptr, + const index_t stride_q_descale_token = 0, + const index_t stride_k_descale_token = 0) const { static_assert( std::is_same_v> && @@ -463,6 +472,87 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN0, kK0>{})); } __builtin_amdgcn_sched_barrier(1); + + // PER_TOKEN_HEAD: dequantize QK with per-row Q descale and per-col K descale. + // s_acc[i,j] *= q_descale[(q_origin + i) * stride_q_descale_token] + // * k_descale[(k_origin + j) * stride_k_descale_token] + // q_descale_ptr / k_descale_ptr are already (batch, head)-resolved by the kernel. + // + // qr_async (no trload) on gfx9 is very tight on SGPRs; folding both row + col + // global loads into a single 2-D sweep over s_acc made the compiler spill the + // K/V buffer-SRDs to VGPRs and produce invalid asm. We instead split the + // dequant into two 1-D sweeps so each pass keeps a single descale pointer + // live and reuses the existing tile-distribution machinery without inflating + // SGPR pressure inside the K/V async-load region. + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD dequant: s_acc[i,j] *= q_descale[i] * k_descale[j]. + // + // qr_async on gfx9 is extremely tight on SGPRs (the K/V buffer-SRDs + // must stay live across iterations of the main K-loop). When threads + // issue scalar global loads from k_descale_ptr inside the per-tile + // 2-D sweep, the compiler runs out of SGPR budget and spills the K + // SRD into VGPRs, producing invalid `buffer_load_dword v_dst, v[..] + // ... lds` asm on gfx942. + // + // To break that dependency we stage both Q-row and K-col descales + // through LDS first (one warp-strided load per descale tile), then + // the per-element sweep is a pure LDS read + FP multiply. This keeps + // the K SRD live-range contained inside the QK gemm, where the + // existing schedule already accounts for it. + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + + const index_t q_row_base = q_origin.at(number<0>{}); + + // Mirror the V / randval path's sink-aware col-base derivation so we + // never touch k_dram_block_window.get_window_origin() inside this + // SGPR-tight region (that read leaks the K window state and was + // independently observed to push the SRD into VGPRs). + const bool in_sink_phase = (num_sink_loop > i_total_loops); + const index_t k_col_base = + in_sink_phase + ? (kv_load_start + i_total_loops * kN0) + : (seqlen_k_start + (i_total_loops - num_sink_loop) * kN0); + + __builtin_amdgcn_sched_barrier(0); + + // LDS staging tiles. Allocated per-block; kept small (kM0 + kN0 fp32). + __shared__ float lds_q_descale[kM0]; + __shared__ float lds_k_descale[kN0]; + + const index_t tid_in_block = + static_cast(threadIdx.x + threadIdx.y * blockDim.x + + threadIdx.z * blockDim.x * blockDim.y); + const index_t threads_per_block = + static_cast(blockDim.x * blockDim.y * blockDim.z); + + // Q-row descales (kM0 entries). + for(index_t off = tid_in_block; off < kM0; off += threads_per_block) + { + lds_q_descale[off] = + q_descale_ptr[(q_row_base + off) * stride_q_descale_token]; + } + // K-col descales (kN0 entries). + for(index_t off = tid_in_block; off < kN0; off += threads_per_block) + { + lds_k_descale[off] = + k_descale_ptr[(k_col_base + off) * stride_k_descale_token]; + } + __builtin_amdgcn_s_barrier(); + + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + const index_t i = tile_idx.at(number<0>{}); + const index_t j = tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + s_acc(i_j_idx) *= lds_q_descale[i] * lds_k_descale[j]; + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + // dequant auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() { if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) @@ -785,12 +875,19 @@ struct BlockFmhaPipelineQRKSVSAsync const index_t kv_idx = (kv_load_start + i_total_loops * kN0) / block_scale_size_kv; v_descale = v_descale_ptr[kv_idx]; } + else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) + { + // PER_TOKEN_HEAD: V scale is per-head only; v_descale_ptr is already + // (batch, head)-resolved by the kernel, so load the single scalar. + v_descale = *v_descale_ptr; + } // STAGE 3, KV gemm auto o_acc0 = decltype(o_acc){}; clear_tile(o_acc0); auto& o_acc_ = [&o_acc0, &o_acc]() -> auto& { - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { return o_acc0; } @@ -879,7 +976,8 @@ struct BlockFmhaPipelineQRKSVSAsync sequence<(LdsSeq.at(number{}) + 1) * kN1, kK1>{})); } - if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE || + QScaleEnum == BlockAttentionQuantScaleEnum::PER_TOKEN_HEAD) { tile_elementwise_inout( [&v_descale](auto& o, auto& o0) { o += o0 * v_descale; }, o_acc, o_acc0);