diff --git a/example/ck_tile/50_sparse_attn/README.md b/example/ck_tile/50_sparse_attn/README.md
new file mode 100644
index 0000000000..c7191c8e82
--- /dev/null
+++ b/example/ck_tile/50_sparse_attn/README.md
@@ -0,0 +1,45 @@
+# Sparge Attention (Composable Kernel)
+
+A Composable Kernel port of [SpargeAttn](https://github.com/thu-ml/SpargeAttn) for AMD GPUs. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are exposed via `-pipeline=vsa` (default, faster) and `-pipeline=jenga` (async K/V load variant).
+
+## Status vs Upstream
+
+Implemented:
+- per-block mean-pool, cosine similarity, pooled QK
+- top-k / `cdfthreshd` block selection, BlockMap LUT
+- sparse FMHA (both `vsa` and `jenga` backends)
+- per-head `topk` / `simthreshd1` / `cdfthreshd`
+
+Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):
+- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
+- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
+- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
+- **pv_threshold per-Q-tile skip in attn kernel** — pure perf, ~5–15% on the dominant attention slice ([spas_sage_attn/core.py:L265](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L265))
+- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
+- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))
+
+## Performance
+
+At b=2 h=32 s=16384 fp16, sparge (vsa backend) reaches **1.78× FMHA throughput at topk=0.4** and **5.04× at topk=0.1**, and stays above 1.0× across the full topk range.
+
+![Speedup vs sparsity](docs/speedup_vs_sparsity.png)
+
+*Speedup vs FMHA, b=2 h=32 s=16384 d=128 fp16. Shape chosen to match Fig. 10 of the SpargeAttn paper ([arXiv:2502.18137](https://arxiv.org/abs/2502.18137); Mochi-1, 22K context, head_dim=128); s=16384 is the closest grid point. Gray-outlined points have >30% inter-rep spread.*
+
+![Kernel breakdown](docs/kernel_breakdown.png)
+
+*BlockMap (`_pre`) stacked on attention (`_attn`), b=2 h=32 d=128 fp16 topk=0.4. BlockMap is roughly 17% of total at s=16384.*
+
+## Usage
+
+```bash
+ninja tile_example_sparge
+./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
+```
+
+Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since the full-shape CPU reference scales as O(s²) and runs 30+ minutes at s=8k, hours at s=16k.
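+
+## Algorithm Sketch
+
+For orientation, the NumPy sketch below mirrors the block-selection math on a single (batch, head): per-block mean-pool, a per-K-block self-similarity gate, pooled QK scores, then top-k. It is a simplified reference, not the kernel API: it follows the `topk` path only, omits the Q-side `simthreshd1` gate and the `cdfthreshd` path, and all names and shapes are illustrative.
+
+```python
+import numpy as np
+
+def block_map_reference(q, k, kM0=64, kN0=128, simthreshd1=0.001, topk=0.4):
+    """Toy reference: q [s_q, d], k [s_k, d] -> uint8 block map [N_q, N_k]."""
+    scale = 1.0 / np.sqrt(q.shape[1])
+
+    def pool(x, blk):
+        means, sims = [], []
+        for i in range(0, x.shape[0], blk):
+            xb = x[i:i + blk]
+            means.append(xb.mean(axis=0))
+            # self-similarity: ||sum of unit-norm rows||^2 / block_len^2
+            unit = xb / np.maximum(np.linalg.norm(xb, axis=1, keepdims=True), 1e-30)
+            sims.append(float(np.square(unit.sum(axis=0)).sum()) / xb.shape[0] ** 2)
+        return np.stack(means), np.array(sims)
+
+    qm, _ = pool(q, kM0)        # pooled Q means       [N_q, d]
+    km, ksim = pool(k, kN0)     # pooled K means, sims [N_k, d], [N_k]
+    scores = (qm @ km.T) * scale
+
+    N_q, N_k = scores.shape
+    bmap = np.zeros((N_q, N_k), dtype=np.uint8)
+    forced = ksim <= simthreshd1   # low-similarity K blocks: forced ON and
+    bmap[:, forced] = 1            # excluded from the top-k budget below
+    keep = max(1, int(topk * N_k))
+    for i in range(N_q):
+        row = np.where(forced, -np.inf, scores[i])
+        # softmax is monotone, so top-k on raw scores matches the kernel's
+        # softmax-then-argmax selection
+        bmap[i, np.argsort(row)[::-1][:keep]] = 1
+    return bmap
+```
+
+A `topk=0.4` call keeps roughly 40% of K blocks per Q block, plus any forced-on low-similarity blocks.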
+ +## References + +- [SpargeAttn upstream](https://github.com/thu-ml/SpargeAttn) (pinned to [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)) +- [Paper — Zhang et al., arXiv:2502.18137](https://arxiv.org/abs/2502.18137) diff --git a/example/ck_tile/50_sparse_attn/docs/kernel_breakdown.png b/example/ck_tile/50_sparse_attn/docs/kernel_breakdown.png new file mode 100644 index 0000000000..8704334155 Binary files /dev/null and b/example/ck_tile/50_sparse_attn/docs/kernel_breakdown.png differ diff --git a/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py b/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py new file mode 100644 index 0000000000..95a13d5f65 --- /dev/null +++ b/example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py @@ -0,0 +1,258 @@ +#!/usr/bin/env python3 +"""Plot sparge perf charts from full_grid.csv. + +Re-run with different fixed (b, h, s, dtype, topk) by editing the constants below. +No GPU / no srun / no rebuild — pure matplotlib from CSV. +""" +import os +import sys +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np + +# ---------------------------------------------------------------------- +# Tunable constants — edit these to regenerate for a different point. +# ---------------------------------------------------------------------- +CSV_PATH = "/home/AMD/ginolu12/gino_tmp/full_grid.csv" +OUT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# Chart 1 — speedup vs topk for one fixed (b, h, s, dtype) +CHART1_B = 2 +CHART1_H = 32 +CHART1_S = 16384 +CHART1_DTYPE = "fp16" +CHART1_HEAD_DIM = 128 # for title only + +# Chart 2 — kernel breakdown across s for fixed (b, h, dtype, topk) +CHART2_B = 2 +CHART2_H = 32 +CHART2_DTYPE = "fp16" +CHART2_TOPK = 0.4 +CHART2_S_LIST = [2048, 4096, 8192, 16384] +CHART2_HEAD_DIM = 128 # for title only + +DPI = 140 + +# ---------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------- +def is_fail(note): + if not isinstance(note, str): + return False + return "FAIL" in note + +def is_high_spread(note): + if not isinstance(note, str): + return False + return "HIGH_SPREAD" in note + +def load_data(): + df = pd.read_csv(CSV_PATH) + return df + +# ---------------------------------------------------------------------- +# Chart 1 +# ---------------------------------------------------------------------- +def plot_chart1(df, out_path): + sel = df[ + (df["b"] == CHART1_B) + & (df["h"] == CHART1_H) + & (df["s"] == CHART1_S) + & (df["dtype"] == CHART1_DTYPE) + ].copy() + sel = sel.sort_values("topk").reset_index(drop=True) + + if sel.empty: + print(f"[chart1] WARNING: no rows for b={CHART1_B} h={CHART1_H} s={CHART1_S} dtype={CHART1_DTYPE}") + return [], 0 + + # Drop fully failed rows but keep partial-fail rows; we'll mask per-series. + # Convert numeric columns + for col in ["sparge_jenga", "sparge_vsa", "sparse_jenga", "sparse_vsa", "fmha_us"]: + sel[col] = pd.to_numeric(sel[col], errors="coerce") + + fmha = sel["fmha_us"] + + # Compute speedups; rows with FAIL on a given column will have NaN already. 
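+    # speedup(col) = fmha_us / col_us at each topk row. A FAIL leaves NaN in that
+    # column (errors="coerce" above), which matplotlib draws as a gap in the line
+    # instead of dropping the whole series.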
+ series = { + "sparge_vsa": fmha / sel["sparge_vsa"], + "sparge_jenga": fmha / sel["sparge_jenga"], + "sparse_vsa": fmha / sel["sparse_vsa"], + "sparse_jenga": fmha / sel["sparse_jenga"], + } + + style = { + "sparge_vsa": {"color": "#1f77b4", "marker": "o", "lw": 2.0}, + "sparge_jenga": {"color": "#ff7f0e", "marker": "s", "lw": 2.0}, + "sparse_vsa": {"color": "#2ca02c", "marker": "^", "lw": 1.5, "ls": "--"}, + "sparse_jenga": {"color": "#d62728", "marker": "v", "lw": 1.5, "ls": "--"}, + } + + fig, ax = plt.subplots(figsize=(8.5, 5.5), dpi=DPI) + + x = sel["topk"].to_numpy() + + # HIGH_SPREAD overlay first (under main markers) + hs_mask = sel["note"].apply(is_high_spread) + high_spread_cells = [] + if hs_mask.any(): + for _, row in sel[hs_mask].iterrows(): + high_spread_cells.append((row["topk"], row["max_spread_pct"])) + # gray ring underneath every series's data point at that x + for label, sp in series.items(): + xs_hs = x[hs_mask.to_numpy()] + ys_hs = sp[hs_mask.to_numpy()].to_numpy() + ax.scatter(xs_hs, ys_hs, s=180, facecolors="none", + edgecolors="gray", linewidths=1.5, zorder=2) + + for label, sp in series.items(): + st = style[label] + ax.plot(x, sp.to_numpy(), label=label, + color=st["color"], marker=st["marker"], + linewidth=st["lw"], linestyle=st.get("ls", "-"), + markersize=7, zorder=3) + + ax.axhline(1.0, color="black", linestyle=":", linewidth=1.2, label="fmha (baseline)", zorder=1) + + ax.set_xlabel("topk (kept fraction)") + ax.set_ylabel("speedup vs FMHA dense (×)") + ax.set_title( + f"Speedup vs FMHA " + f"(b={CHART1_B} h={CHART1_H} s={CHART1_S} d={CHART1_HEAD_DIM} {CHART1_DTYPE})" + ) + ax.grid(True, which="both", linestyle=":", alpha=0.6) + ax.set_xticks(np.arange(0.1, 0.71, 0.1)) + ax.legend(loc="best", framealpha=0.9) + + # Footnote about HIGH_SPREAD overlay + if high_spread_cells: + ax.text(0.01, -0.16, + "Gray rings: HIGH_SPREAD cells (high run-to-run variance)", + transform=ax.transAxes, fontsize=8, color="gray") + + fig.tight_layout() + fig.savefig(out_path, dpi=DPI, bbox_inches="tight") + plt.close(fig) + return high_spread_cells, os.path.getsize(out_path) + + +# ---------------------------------------------------------------------- +# Chart 2 +# ---------------------------------------------------------------------- +def plot_chart2(df, out_path): + sel = df[ + (df["b"] == CHART2_B) + & (df["h"] == CHART2_H) + & (df["dtype"] == CHART2_DTYPE) + & (np.isclose(df["topk"], CHART2_TOPK)) + & (df["s"].isin(CHART2_S_LIST)) + ].copy() + sel = sel.sort_values("s").reset_index(drop=True) + + if sel.empty: + print(f"[chart2] WARNING: no rows for b={CHART2_B} h={CHART2_H} dtype={CHART2_DTYPE} topk={CHART2_TOPK}") + return 0 + + for col in ["sparge_jenga_pre", "sparge_jenga_attn", + "sparge_vsa_pre", "sparge_vsa_attn", "fmha_us"]: + sel[col] = pd.to_numeric(sel[col], errors="coerce") + + s_vals = sel["s"].to_numpy() + n = len(s_vals) + idx = np.arange(n, dtype=float) + + width = 0.35 + offset = width / 2 + 0.02 + + fig, ax = plt.subplots(figsize=(9.0, 5.8), dpi=DPI) + + # Jenga bars (left of group) + jenga_pre = sel["sparge_jenga_pre"].to_numpy() + jenga_attn = sel["sparge_jenga_attn"].to_numpy() + vsa_pre = sel["sparge_vsa_pre"].to_numpy() + vsa_attn = sel["sparge_vsa_attn"].to_numpy() + fmha_vals = sel["fmha_us"].to_numpy() + + color_jenga_pre = "#fdbf6f" # light orange + color_jenga_attn = "#ff7f0e" # orange + color_vsa_pre = "#a6cee3" # light blue + color_vsa_attn = "#1f77b4" # blue + + bj_pre = ax.bar(idx - offset, jenga_pre, width, + color=color_jenga_pre, 
edgecolor="black", linewidth=0.6, + label="sparge_jenga _pre (BlockMap)") + bj_at = ax.bar(idx - offset, jenga_attn, width, bottom=jenga_pre, + color=color_jenga_attn, edgecolor="black", linewidth=0.6, + label="sparge_jenga _attn") + bv_pre = ax.bar(idx + offset, vsa_pre, width, + color=color_vsa_pre, edgecolor="black", linewidth=0.6, + label="sparge_vsa _pre (BlockMap)") + bv_at = ax.bar(idx + offset, vsa_attn, width, bottom=vsa_pre, + color=color_vsa_attn, edgecolor="black", linewidth=0.6, + label="sparge_vsa _attn") + + # Add total labels on top of each stack + totals_jenga = jenga_pre + jenga_attn + totals_vsa = vsa_pre + vsa_attn + for i in range(n): + ax.text(idx[i] - offset, totals_jenga[i], f"{totals_jenga[i]:.0f}", + ha="center", va="bottom", fontsize=8) + ax.text(idx[i] + offset, totals_vsa[i], f"{totals_vsa[i]:.0f}", + ha="center", va="bottom", fontsize=8) + + # FMHA reference: short horizontal dashed segment per group + seg_half = 0.40 + fmha_label_done = False + for i in range(n): + ax.hlines(fmha_vals[i], idx[i] - seg_half, idx[i] + seg_half, + colors="black", linestyles="dashed", linewidth=1.2, + label="fmha dense (reference)" if not fmha_label_done else None, + zorder=5) + ax.text(idx[i] + seg_half + 0.02, fmha_vals[i], + f"fmha {fmha_vals[i]:.0f}", fontsize=7, va="center", color="black") + fmha_label_done = True + + ax.set_xticks(idx) + ax.set_xticklabels([f"s={s}" for s in s_vals.astype(int)]) + ax.set_xlabel("sequence length (s)") + ax.set_ylabel("kernel time (µs)") + ax.set_title( + f"Sparge kernel time breakdown " + f"(b={CHART2_B} h={CHART2_H} d={CHART2_HEAD_DIM} {CHART2_DTYPE}, topk={CHART2_TOPK})" + ) + ax.grid(True, axis="y", linestyle=":", alpha=0.6) + ax.legend(loc="upper left", framealpha=0.9, fontsize=9) + + # log-y is too aggressive — leave linear; bars will just be tall. 
+    fig.tight_layout()
+    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
+    plt.close(fig)
+    return os.path.getsize(out_path)
+
+
+# ----------------------------------------------------------------------
+# Main
+# ----------------------------------------------------------------------
+def main():
+    os.makedirs(OUT_DIR, exist_ok=True)
+    df = load_data()
+
+    chart1_path = os.path.join(OUT_DIR, "speedup_vs_sparsity.png")
+    chart2_path = os.path.join(OUT_DIR, "kernel_breakdown.png")
+
+    hs_cells, size1 = plot_chart1(df, chart1_path)
+    size2 = plot_chart2(df, chart2_path)
+
+    print(f"Wrote {chart1_path} ({size1} bytes)")
+    print(f"Wrote {chart2_path} ({size2} bytes)")
+
+    if hs_cells:
+        print("HIGH_SPREAD cells in chart-1 selection:")
+        for topk, pct in hs_cells:
+            print(f"  topk={topk} max_spread_pct={pct}")
+    else:
+        print("No HIGH_SPREAD cells in chart-1 selection.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/example/ck_tile/50_sparse_attn/docs/speedup_vs_sparsity.png b/example/ck_tile/50_sparse_attn/docs/speedup_vs_sparsity.png
new file mode 100644
index 0000000000..9a2f053b0b
Binary files /dev/null and b/example/ck_tile/50_sparse_attn/docs/speedup_vs_sparsity.png differ
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
index a2df5bac56..3cc674f181 100644
--- a/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_inst.cpp
@@ -5,6 +5,9 @@
 #include "sparge_blockmap_trek.hpp"
 #include "ck_tile/ops/fmha/block/variants.hpp"
 
+#include <cstddef>
+#include <cstdint>
+#include <hip/hip_runtime.h>
 #include <iostream>
 
 // ============================================================================
@@ -61,6 +64,9 @@
 using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem;
 using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel;
 
+using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_fp16_problem>;
+using kstats_fp16_kernel   = ck_tile::SpargeKStatsKernel<kstats_fp16_pipeline>;
+
 // ============================================================================
 // bf16: D=128, kM0=64, kN0=128
 // ============================================================================
@@ -112,6 +118,78 @@
 using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem;
 using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel;
 
+using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_bf16_problem>;
+using kstats_bf16_kernel   = ck_tile::SpargeKStatsKernel<kstats_bf16_pipeline>;
+
+// ============================================================================
+// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized
+// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged.
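+// Sizing is grow-only: ensure() reallocates only when a larger shape arrives,
+// and nothing is freed before process exit, so the largest shape seen bounds
+// the footprint. ensure() takes no lock, so the launch wrappers below are not
+// safe to call concurrently from multiple host threads.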
+// ============================================================================
+
+namespace {
+
+struct KStatsWorkspace
+{
+    void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32
+    void* sim_k_dev    = nullptr; // [batch, nhead_k, N_k] uint8
+    size_t pooled_k_bytes = 0;
+    size_t sim_k_bytes    = 0;
+
+    void ensure(int batch, int nhead_k, int N_k, int D)
+    {
+        const size_t need_p = static_cast<size_t>(batch) * nhead_k * N_k * D * sizeof(float);
+        const size_t need_s = static_cast<size_t>(batch) * nhead_k * N_k * sizeof(uint8_t);
+        if(need_p > pooled_k_bytes)
+        {
+            if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev);
+            (void)hipMalloc(&pooled_k_dev, need_p);
+            pooled_k_bytes = need_p;
+        }
+        if(need_s > sim_k_bytes)
+        {
+            if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev);
+            (void)hipMalloc(&sim_k_dev, need_s);
+            sim_k_bytes = need_s;
+        }
+    }
+};
+
+KStatsWorkspace& g_kstats_ws()
+{
+    static KStatsWorkspace ws;
+    return ws;
+}
+
+template <typename KStatsKernel, typename BlockMapKernel>
+void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s)
+{
+    const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0);
+    const int D   = BlockMapKernel::D;
+    auto& ws = g_kstats_ws();
+    ws.ensure(args.batch, args.nhead_k, N_k, D);
+
+    // Stage 1: K stats
+    {
+        auto [kargs, grids] =
+            sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, ws.pooled_k_dev, ws.sim_k_dev);
+        const dim3 blocks = KStatsKernel::BlockSize();
+        constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
+        ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(
+            ck_tile::stream_config{s.stream_id_});
+    }
+    // Stage 2: block_map (reads ws)
+    {
+        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(
+            args, ws.pooled_k_dev, ws.sim_k_dev);
+        const dim3 blocks = BlockMapKernel::BlockSize();
+        constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
+        ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(
+            ck_tile::stream_config{s.stream_id_});
+    }
+}
+
+} // namespace
+
 // ============================================================================
 // Dispatch
 // ============================================================================
@@ -122,26 +200,20 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits,
 {
     if(traits.data_type == "fp16" && traits.hdim_q == 128)
     {
-        using k_ = bmap_fp16_kernel;
         if(s.log_level_ > 0)
             std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
-        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
-        const dim3 blocks = k_::BlockSize();
-        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-        return ck_tile::launch_kernel(
-            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
+        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
+            launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s_);
+        });
     }
     if(traits.data_type == "bf16" && traits.hdim_q == 128)
     {
-        using k_ = bmap_bf16_kernel;
         if(s.log_level_ > 0)
             std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
-        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
-        const dim3 blocks = k_::BlockSize();
-        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-        return ck_tile::launch_kernel(
-            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
+        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
+            launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s_);
+        });
     }
 
     if(s.log_level_ > 0)
@@ -160,23 +232,13 @@ void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
 {
     if(traits.data_type == "fp16" && traits.hdim_q == 128)
     {
-        using k_ = bmap_fp16_kernel;
-        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
-        const dim3 blocks = k_::BlockSize();
-        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
-            ck_tile::stream_config{s.stream_id_});
+        launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s);
         return;
     }
 
     if(traits.data_type == "bf16" && traits.hdim_q == 128)
     {
-        using k_ = bmap_bf16_kernel;
-        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
-        const dim3 blocks = k_::BlockSize();
-        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
-        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
-            ck_tile::stream_config{s.stream_id_});
+        launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s);
         return;
     }
diff --git a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
index 6eaeb9ea77..92c32d29e8 100644
--- a/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
+++ b/example/ck_tile/50_sparse_attn/sparge_blockmap_trek.hpp
@@ -8,7 +8,9 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
 #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
 #include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
+#include "ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp"
 #include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
+#include "ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp"
 
 #include "fmha_fwd_trek.hpp"
 
@@ -45,6 +47,15 @@ struct sparge_blockmap_args
     void* block_map_ptr;
     void* lut_ptr;
     void* valid_block_num_ptr;
+
+    // R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar.
+    // Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized
+    // by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the
+    // first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for
+    // MQA/GQA aligns to upstream tuned ckpt layout.
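+    // Example (GQA): nhead_q=32, nhead_k=8 -> all three buffers still hold 32
+    // floats; the K-side consumer (simthreshd1 in the K-stats kernel) reads only
+    // entries [0..7] via hk, while topk/cdfthreshd are indexed per Q head (hq).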
+    const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1])
+    const float* cdfthreshd_per_head_ptr = nullptr;  // size = nhead_q floats
+    const float* topk_per_head_ptr = nullptr;        // size = nhead_q floats
 };
 
 struct sparge_blockmap_traits
@@ -57,7 +68,9 @@
 // Create kernel args and grid dimensions
 // ============================================================================
 template <typename BlockMapKernel>
-auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
+auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args,
+                                            const void* pooled_k_ws_ptr,
+                                            const void* sim_k_ws_ptr)
 {
     assert(args.nhead_q % args.nhead_k == 0);
     auto kargs = BlockMapKernel::MakeKargs(args.q_ptr,
@@ -79,12 +92,38 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
                                            args.scale,
                                            args.block_map_ptr,
                                            args.lut_ptr,
-                                           args.valid_block_num_ptr);
+                                           args.valid_block_num_ptr,
+                                           pooled_k_ws_ptr,
+                                           sim_k_ws_ptr,
+                                           args.topk_per_head_ptr,
+                                           args.cdfthreshd_per_head_ptr);
 
     dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
     return ck_tile::make_tuple(kargs, grids);
 }
 
+template <typename KStatsKernel>
+auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args,
+                                          void* pooled_k_ws_ptr,
+                                          void* sim_k_ws_ptr)
+{
+    assert(args.nhead_q % args.nhead_k == 0);
+    auto kargs = KStatsKernel::MakeKargs(args.k_ptr,
+                                         args.seqlen_k,
+                                         args.hdim_q,
+                                         args.nhead_k,
+                                         args.stride_k,
+                                         args.nhead_stride_k,
+                                         args.batch_stride_k,
+                                         args.simthreshd1,
+                                         pooled_k_ws_ptr,
+                                         sim_k_ws_ptr,
+                                         args.simthreshd1_per_head_ptr);
+
+    dim3 grids = KStatsKernel::GridSize(args.batch, args.nhead_k, args.seqlen_k);
+    return ck_tile::make_tuple(kargs, grids);
+}
+
 // ============================================================================
 // Hand-written template instantiation dispatch
 // ============================================================================
diff --git a/example/ck_tile/50_sparse_attn/test_sparge.cpp b/example/ck_tile/50_sparse_attn/test_sparge.cpp
index 81a49ca006..4c97a10d0f 100644
--- a/example/ck_tile/50_sparse_attn/test_sparge.cpp
+++ b/example/ck_tile/50_sparse_attn/test_sparge.cpp
@@ -105,7 +105,10 @@ auto create_args(int argc, char* argv[])
         .insert("seed", "42", "random seed")
         .insert("warmup", "5", "warmup iterations")
         .insert("repeat", "20", "benchmark iterations")
-        .insert("kname", "0", "print kernel name");
+        .insert("kname", "0", "print kernel name")
+        .insert("perhead", "0",
+                "R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
+                "(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -135,6 +138,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
     int warmup = arg_parser.get_int("warmup");
     int repeat = arg_parser.get_int("repeat");
     int kname = arg_parser.get_int("kname");
+    int perhead = arg_parser.get_int("perhead");
 
     if(nhead_k < 0)
         nhead_k = nhead;
     if(seqlen_k < 0)
         seqlen_k = seqlen_q;
@@ -231,6 +235,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
     bmap_args.lut_ptr = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
     bmap_args.valid_block_num_ptr =
         (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;
 
+    // R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
+    // to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
+    // K-side kernel reads only the first nhead_k entries via [hk].
+    ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
+    ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
+    ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
+    if(perhead != 0)
+    {
+        std::vector<float> topk_h(nhead);
+        std::vector<float> sim1_h(nhead);
+        std::vector<float> cdf_h (nhead);
+        for(int h = 0; h < nhead; ++h)
+        {
+            // small per-head jitter around scalar topk so sparsity differs by head
+            const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
+            topk_h[h] = topk * (1.0f + jitter);
+            sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
+            cdf_h[h]  = cdfthreshd;
+        }
+        topk_per_head_dev.ToDevice(topk_h.data());
+        sim1_per_head_dev.ToDevice(sim1_h.data());
+        cdf_per_head_dev .ToDevice(cdf_h.data());
+        bmap_args.topk_per_head_ptr =
+            static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
+        bmap_args.simthreshd1_per_head_ptr =
+            static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
+        bmap_args.cdfthreshd_per_head_ptr =
+            static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
+    }
+
     // ---- build attention args ----
     ck_tile::stream_config stream_cfg;
     stream_cfg.stream_id_ = nullptr;
diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
index ca177abf23..62b5b3591c 100644
--- a/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
+++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp
@@ -52,7 +52,20 @@ struct SpargeBlockMapKernel
         void* lut_ptr;
         void* valid_block_num_ptr;
 
+        // R20 K-stat workspace from Kernel A
+        const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
+        const void* sim_k_ws_ptr;    // [batch, nhead_k, N_k] uint8
+        index_t N_k;
+
+        // R21A Phase 4: optional per-head topk (size = nhead_q floats).
+        // nullptr => use scalar `topk` for all heads.
+        const float* topk_per_head;
+
+        // R21B: optional per-head cdfthreshd (size = nhead_q floats).
+        // nullptr => use scalar `cdfthreshd` for all heads.
+        // Only consulted on topk<=0 path; bench currently always uses topk path.
+        const float* cdfthreshd_per_head;
     };
 
     CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
@@ -74,7 +87,11 @@ struct SpargeBlockMapKernel
                                                  float scale,
                                                  void* block_map_ptr,
                                                  void* lut_ptr,
-                                                 void* valid_block_num_ptr)
+                                                 void* valid_block_num_ptr,
+                                                 const void* pooled_k_ws_ptr,
+                                                 const void* sim_k_ws_ptr,
+                                                 const float* topk_per_head = nullptr,
+                                                 const float* cdfthreshd_per_head = nullptr)
     {
         const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
         return Kargs{q_ptr,
@@ -97,7 +114,11 @@ struct SpargeBlockMapKernel
                      block_map_ptr,
                      lut_ptr,
                      valid_block_num_ptr,
-                     N_k};
+                     pooled_k_ws_ptr,
+                     sim_k_ws_ptr,
+                     N_k,
+                     topk_per_head,
+                     cdfthreshd_per_head};
     }
 
     CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
@@ -174,6 +195,21 @@ struct SpargeBlockMapKernel
         // Shared memory
         __shared__ char smem[Pipeline::GetSmemSize()];
 
+        // R20 K-stat workspace: pre-offset for this (b, hk).
+        const index_t nhead_k   = kargs.nhead_q / kargs.nhead_ratio_qk;
+        const index_t khead_off = (b * nhead_k + hk) * N_k;
+        const auto* pooled_k_ws =
+            reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
+        const auto* sim_k_ws =
+            reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
+
+        // R21A Phase 4: per-head topk if provided, else scalar broadcast.
+        const float topk_eff =
+            (kargs.topk_per_head != nullptr) ?
kargs.topk_per_head[hq] : kargs.topk; + // R21B: per-head cdfthreshd if provided, else scalar broadcast. + const float cdfthreshd_eff = + (kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd; + Pipeline{}(q_window, k_window, kargs.seqlen_q, @@ -182,12 +218,14 @@ struct SpargeBlockMapKernel N_k, kargs.nhead_ratio_qk, kargs.simthreshd1, - kargs.cdfthreshd, - kargs.topk, + cdfthreshd_eff, + topk_eff, kargs.scale, bmap_ptr, lut_out, valid_out, + pooled_k_ws, + sim_k_ws, static_cast(smem)); } }; diff --git a/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp new file mode 100644 index 0000000000..3ce494f870 --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp @@ -0,0 +1,136 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include + +namespace ck_tile { + +// Kernel A wrapper: grid (N_k, nhead_k, batch). Each work-group precomputes +// K-block stats (pooled_k_mean[D], sim_k) for one (b, hk, kb) into a workspace +// that Kernel B (block_map) reads instead of recomputing per Q-block. +template +struct SpargeKStatsKernel +{ + using Pipeline = remove_cvref_t; + + static constexpr index_t kBlockSize = Pipeline::kBlockSize; + static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu; + + using QDataType = typename Pipeline::QDataType; + using KDataType = typename Pipeline::KDataType; + + static constexpr index_t kN0 = Pipeline::kN0; + static constexpr index_t D = Pipeline::D; + + static constexpr index_t kAlignment = 16 / sizeof(KDataType); + + struct Kargs + { + const void* k_ptr; + + index_t seqlen_k; + index_t hdim_q; + index_t nhead_k; + + index_t stride_k; + index_t nhead_stride_k; + index_t batch_stride_k; + + float simthreshd1; + + void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32 + void* sim_k_ptr; // [batch, nhead_k, N_k] uint8 + + index_t N_k; + + // R21A Phase 4 + R21B fix: optional per-head simthreshd1. + // Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract + // (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first + // nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`. 
+ const float* simthreshd1_per_head; + }; + + CK_TILE_HOST static constexpr auto MakeKargs(const void* k_ptr, + index_t seqlen_k, + index_t hdim_q, + index_t nhead_k, + index_t stride_k, + index_t nhead_stride_k, + index_t batch_stride_k, + float simthreshd1, + void* pooled_k_ptr, + void* sim_k_ptr, + const float* simthreshd1_per_head = nullptr) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return Kargs{k_ptr, + seqlen_k, + hdim_q, + nhead_k, + stride_k, + nhead_stride_k, + batch_stride_k, + simthreshd1, + pooled_k_ptr, + sim_k_ptr, + N_k, + simthreshd1_per_head}; + } + + CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_k, index_t seqlen_k) + { + const index_t N_k = integer_divide_ceil(seqlen_k, kN0); + return dim3(N_k, nhead_k, batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const index_t kb = static_cast(blockIdx.x); + const index_t hk = static_cast(blockIdx.y); + const index_t b = static_cast(blockIdx.z); + + const auto* k_base = reinterpret_cast(kargs.k_ptr) + + b * kargs.batch_stride_k + hk * kargs.nhead_stride_k + + kb * kN0 * kargs.stride_k; + + const auto k_dram_naive = make_naive_tensor_view( + k_base, + make_tuple(kargs.seqlen_k - kb * kN0, D), + make_tuple(kargs.stride_k, 1), + number{}, + number<1>{}); + const auto k_dram = pad_tensor_view( + k_dram_naive, make_tuple(number{}, number{}), sequence{}); + + auto k_window = make_tile_window(k_dram, + make_tuple(number{}, number{}), + {0, 0}, + Pipeline::MakeKBlockDistribution()); + + const index_t N_k = kargs.N_k; + const index_t khead_off = (b * kargs.nhead_k + hk) * N_k; + auto* pooled_k_out = reinterpret_cast(kargs.pooled_k_ptr) + (khead_off + kb) * D; + auto* sim_k_out = reinterpret_cast(kargs.sim_k_ptr) + (khead_off + kb); + + __shared__ char smem[Pipeline::GetSmemSize()]; + + // R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast. + const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr) + ? kargs.simthreshd1_per_head[hk] + : kargs.simthreshd1; + + Pipeline{}(k_window, + kargs.seqlen_k, + kb, + simthreshd1_eff, + pooled_k_out, + sim_k_out, + static_cast(smem)); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp index 222e73c60e..25e3b964e9 100644 --- a/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp @@ -32,14 +32,22 @@ struct SpargeBlockMapPipeline static constexpr index_t kMaxKBlocks = 1024; // LDS layout (non-overlapping, all used simultaneously in Phase 2): - // [0 .. kReduceBytes) cross-warp reduction scratch - // [kScoreOffset ..) scores[N_k] - // [kBmapOffset ..) block_map[N_k] - // [kSmallOffset ..) Phase 3 argmax scratch (2*NumWarps floats) - static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float); - static constexpr index_t kScoreOffset = kReduceBytes; - static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); - static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); + // [0 .. kReduceBytes) cross-warp reduction scratch slab 0 + // [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1 + // (Round 8 b1: ping-pong for K-loop double buffer) + // [kScoreOffset ..) scores[N_k] + // [kBmapOffset ..) 
block_map[N_k] + // [kSmallOffset ..) Phase 3 argmax scratch (2*NumWarps floats) + // B2.v3 column-stride pad: replace k_idx*KPerThread with k_idx*(KPerThread+1) + // to break the 4-way intra-warp bank conflict. New per-warp slab size: + // KThreads * (KPerThread + 1) floats. + static constexpr index_t kColPaddedStride = KPerThread + 1; + static constexpr index_t kPerWarpFloats = KThreads * kColPaddedStride; + static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float); + static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // Round 8 b1: 2 slabs + static constexpr index_t kScoreOffset = kReduceTotalBytes; + static constexpr index_t kBmapOffset = kScoreOffset + kMaxKBlocks * sizeof(float); + static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { @@ -98,6 +106,12 @@ struct SpargeBlockMapPipeline } // Cross-warp LDS reduction for column sums. + // Round 13f: templated TrailingSync flag. When false, the trailing __syncthreads() + // is dropped — only safe when the next access targets a *different* slab and the + // intervening work does not read smem_reduce. Used at the slab_b call in Phase 2 + // K-loop, where the next iter's first cross-warp reduce writes to slab_a (different + // address) and is preceded by its own leading sync. + template CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread], float* __restrict__ smem_reduce) { @@ -107,17 +121,21 @@ struct SpargeBlockMapPipeline const index_t k_idx = lane_id % KThreads; const index_t m_idx = lane_id / KThreads; + // B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8, + // changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0, + // lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0. if(m_idx == 0) for(index_t k = 0; k < KPerThread; ++k) - smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k]; + smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k]; __syncthreads(); for(index_t k = 0; k < KPerThread; ++k) col_acc[k] = 0.f; for(index_t w = 0; w < NumWarps; ++w) for(index_t k = 0; k < KPerThread; ++k) - col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k]; - __syncthreads(); + col_acc[k] += smem_reduce[w * kPerWarpFloats + k_idx * kColPaddedStride + k]; + if constexpr(TrailingSync) + __syncthreads(); } // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx. @@ -162,7 +180,8 @@ struct SpargeBlockMapPipeline for(index_t m = 0; m < SeqPerThread; ++m) { - float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f; + // Round 12: hardware fast rsqrt (v_rsq_f32, ~1 ULP) replaces sw sqrt+rcp. + float inv_norm = (row_norms[m] > 0.f) ? 
rsqrtf(row_norms[m]) : 0.f; index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx; if(gsq < actual_seq) for(index_t k = 0; k < KPerThread; ++k) @@ -230,9 +249,9 @@ struct SpargeBlockMapPipeline // ====================================================================== template CK_TILE_DEVICE void operator()(const QWindowType& q_window_in, - const KWindowType& k_window_in, + const KWindowType& /*k_window_in*/, index_t seqlen_q, - index_t seqlen_k, + index_t /*seqlen_k*/, index_t qb, index_t N_k, index_t /*nhead_ratio_qk*/, @@ -243,11 +262,15 @@ struct SpargeBlockMapPipeline uint8_t* block_map_ptr, int32_t* lut_ptr, int32_t* valid_block_num_ptr, + const float* __restrict__ pooled_k_ws_ptr, + const uint8_t* __restrict__ sim_k_ws_ptr, void* smem_ptr) const { const index_t tid = static_cast(threadIdx.x); - auto* smem_float = reinterpret_cast(smem_ptr); + // R20: K-loop no longer reduces, only Phase 1 uses smem_float0. + // smem_float1 slab is allocated for layout compat but unused. + auto* smem_float0 = reinterpret_cast(smem_ptr); auto* smem_scores = reinterpret_cast(reinterpret_cast(smem_ptr) + kScoreOffset); auto* smem_bmap = @@ -271,16 +294,22 @@ struct SpargeBlockMapPipeline row_reduce_sq_norm(q_data, psq, bs_q); // 1b. Column sum -> mean + // Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab + // (smem_float0) and has its own leading __syncthreads() before reading. + // pooled_q_mean is register-only between reduces. float pooled_q_mean[KPerThread]; column_reduce_thread_and_warp(q_data, pooled_q_mean); - column_reduce_cross_warp(pooled_q_mean, smem_float); + column_reduce_cross_warp(pooled_q_mean, smem_float0); for(index_t k = 0; k < KPerThread; ++k) pooled_q_mean[k] *= inv_bs_q; // 1c. Normalised sum_hat + // Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in + // K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab, + // its leading __syncthreads() covers the WAR. sum_hat register-only here. float sum_hat[KPerThread]; column_reduce_normalised(q_data, psq, sum_hat, bs_q); - column_reduce_cross_warp(sum_hat, smem_float); + column_reduce_cross_warp(sum_hat, smem_float0); // 1d. sim_q = ||sum_hat||^2 / bs_q^2 float sh_sq = 0.f; @@ -319,49 +348,34 @@ struct SpargeBlockMapPipeline smem_bmap[i] = 0; __syncthreads(); - auto k_window = k_window_in; + // R20: K-stats precomputed by Kernel A. Each thread loads its own + // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single + // byte. No K-tile load, no cross-warp reduce in the K-loop. + const index_t lane_id_kb = tid % WarpSize; + const index_t k_idx_kb = lane_id_kb % KThreads; for(index_t kb = 0; kb < N_k; ++kb) { - const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); - const float inv_bs_k = (bs_k > 0) ? 
(1.0f / static_cast(bs_k)) : 0.f; - - auto k_tile = load_tile(k_window); - - float k_data[NPerThread * KPerThread]; - tile_to_float(k_tile, k_data); - - // K mean + const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread; float pooled_k_mean[KPerThread]; - column_reduce_thread_and_warp(k_data, pooled_k_mean); - column_reduce_cross_warp(pooled_k_mean, smem_float); for(index_t k = 0; k < KPerThread; ++k) - pooled_k_mean[k] *= inv_bs_k; + pooled_k_mean[k] = p_kb[k]; - // dot(pooled_q_mean, pooled_k_mean) float dot = 0.f; for(index_t k = 0; k < KPerThread; ++k) dot += pooled_q_mean[k] * pooled_k_mean[k]; dot = reduce_across_k(dot); - // K L2 norms + normalised sum_hat - float k_psq[NPerThread]; - row_reduce_sq_norm(k_data, k_psq, bs_k); - - float k_sum_hat[KPerThread]; - column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k); - column_reduce_cross_warp(k_sum_hat, smem_float); - - // sim_k - float ksh_sq = 0.f; - for(index_t k = 0; k < KPerThread; ++k) - ksh_sq += k_sum_hat[k] * k_sum_hat[k]; - ksh_sq = reduce_across_k(ksh_sq); - const float denom_k = static_cast(bs_k) * static_cast(bs_k); - const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1); + const bool sim_k = (sim_k_ws_ptr[kb] != 0); if(tid == 0) { + // INVARIANT (mirrors SpargeAttn ref utils.py:175-180): + // ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1) + // AND have score = -inf so Phase 3 selection (topk / cdf) does NOT + // pick them again (would double-count toward topk budget). + // Both writes MUST stay together. Any Phase 3 selection rewrite + // (e.g. iterative argmax → bitonic sort) must keep the -inf write. if(!sim_k) { smem_bmap[kb] = 1; @@ -372,10 +386,8 @@ struct SpargeBlockMapPipeline smem_scores[kb] = dot * scale; } } - __syncthreads(); - - move_tile_window(k_window, {kN0, 0}); } + __syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores // ================================================================== // Phase 3: Softmax + Selection @@ -399,15 +411,24 @@ struct SpargeBlockMapPipeline } const float sum_exp = block_reduce_sum(lsum, smem_small); - // normalise - const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; - for(index_t i = tid; i < N_k; i += kBlockSize) - smem_scores[i] *= inv_sum; - __syncthreads(); + // Round 13i: argmax is invariant under positive scaling (inv_sum > 0). When + // topk > 0 we never read normalised values for cdfthreshd, so skip the + // normalise pass entirely (saves N_k LDS writes + 1 __syncthreads). The + // cdfthreshd path (topk <= 0) still requires normalised scores so the + // accumulator `cumulative_prob` matches probabilities. + const bool topk_active = (topk > 0.f); + const float inv_sum = + (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f; + if(!topk_active) + { + for(index_t i = tid; i < N_k; i += kBlockSize) + smem_scores[i] *= inv_sum; + __syncthreads(); + } // Selection: iterative argmax index_t num_to_select = - (topk > 0.f) + topk_active ? max(static_cast(1), static_cast(topk * static_cast(N_k))) : N_k; @@ -448,6 +469,11 @@ struct SpargeBlockMapPipeline } __syncthreads(); + // Round 13g: collapse 2 syncs/round into 1. tid==0 computes the global + // winner AND writes the sentinel (smem_bmap=1, smem_scores=-1) in the same + // critical section, gated by bv>0. All threads then read smem_small[0] for + // the early break / cumulative_prob accumulation. Saves 1 __syncthreads per + // round (~32 syncs @ N_k=64 topk=0.5). 
if(tid == 0) { float bv = smem_small[0]; @@ -462,24 +488,22 @@ struct SpargeBlockMapPipeline bi = wi; } } + // Write sentinel into bmap/scores in the same critical section. + // Guarded by bv > 0 so we never poison a valid score with -1. + if(bv > 0.f) + { + smem_bmap[bi] = 1; + smem_scores[bi] = -1.f; + } smem_small[0] = bv; - smem_small[1] = bit_cast(static_cast(bi)); } __syncthreads(); - float g_val = smem_small[0]; - index_t g_idx = bit_cast(smem_small[1]); + float g_val = smem_small[0]; if(g_val <= 0.f) break; - if(tid == 0) - { - smem_bmap[g_idx] = 1; - smem_scores[g_idx] = -1.f; - } - __syncthreads(); - if(topk > 0.f) { if(round + 1 >= num_to_select) diff --git a/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp new file mode 100644 index 0000000000..1cb96d716a --- /dev/null +++ b/include/ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp @@ -0,0 +1,110 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp" + +namespace ck_tile { + +// Kernel A of the K-stat precompute split: one work-group per (b, hk, kb) +// computes pooled_k_mean and sim_k for that K-block once. Kernel B then reads +// from the workspace instead of recomputing per Q-block. +template +struct SpargeKStatsPipeline +{ + using Problem = remove_cvref_t; + using Base = SpargeBlockMapPipeline; + using QDataType = typename Base::QDataType; + using KDataType = typename Base::KDataType; + + static constexpr index_t kBlockSize = Base::kBlockSize; + static constexpr index_t kM0 = Base::kM0; + static constexpr index_t kN0 = Base::kN0; + static constexpr index_t D = Base::D; + static constexpr index_t NumWarps = Base::NumWarps; + static constexpr index_t WarpSize = Base::WarpSize; + + static constexpr index_t KPerThread = Base::KPerThread; + static constexpr index_t KThreads = Base::KThreads; + static constexpr index_t SeqThreadPerWarp = Base::SeqThreadPerWarp; + static constexpr index_t NPerThread = Base::NPerThread; + + static constexpr index_t kBlockPerCu = 1; + + static constexpr index_t kColPaddedStride = Base::kColPaddedStride; + static constexpr index_t kPerWarpFloats = Base::kPerWarpFloats; + static constexpr index_t kReduceBytes = NumWarps * kPerWarpFloats * sizeof(float); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kReduceBytes; } + + CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution() + { + return Base::MakeKBlockDistribution(); + } + + // operator(): one work-group, one K-block. Writes D fp32 + 1 uint8 to workspace. + template + CK_TILE_DEVICE void operator()(const KWindowType& k_window, + index_t seqlen_k, + index_t kb, + float simthreshd1, + float* __restrict__ pooled_k_out, // D floats + uint8_t* __restrict__ sim_k_out, // 1 byte + void* smem_ptr) const + { + const index_t tid = static_cast(threadIdx.x); + auto* smem_reduce = reinterpret_cast(smem_ptr); + + const index_t bs_k = min(static_cast(kN0), seqlen_k - kb * kN0); + const float inv_bs_k = (bs_k > 0) ? 
(1.0f / static_cast<float>(bs_k)) : 0.f;
+
+        auto k_tile = load_tile(k_window);
+
+        float k_data[NPerThread * KPerThread];
+        Base::tile_to_float(k_tile, k_data);
+
+        const index_t warp_id = tid / WarpSize;
+        const index_t lane_id = tid % WarpSize;
+        const index_t k_idx   = lane_id % KThreads;
+        const index_t m_idx   = lane_id / KThreads;
+
+        // pooled_k_mean: column sum then cross-warp reduce.
+        // R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
+        float pooled_k_mean[KPerThread];
+        Base::column_reduce_thread_and_warp(k_data, pooled_k_mean);
+        Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
+        for(index_t k = 0; k < KPerThread; ++k)
+            pooled_k_mean[k] *= inv_bs_k;
+
+        // R21A: write pooled_k_mean to global early so its register liveness ends here,
+        // freeing VGPR before k_sum_hat becomes live.
+        if(warp_id == 0 && m_idx == 0)
+        {
+            for(index_t k = 0; k < KPerThread; ++k)
+                pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
+        }
+
+        // K row L2 norms + normalised column sum (k_sum_hat)
+        float k_psq[NPerThread];
+        Base::row_reduce_sq_norm(k_data, k_psq, bs_k);
+
+        float k_sum_hat[KPerThread];
+        Base::column_reduce_normalised(k_data, k_psq, k_sum_hat, bs_k);
+        // R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
+        Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);
+
+        // sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1
+        float ksh_sq = 0.f;
+        for(index_t k = 0; k < KPerThread; ++k)
+            ksh_sq += k_sum_hat[k] * k_sum_hat[k];
+        ksh_sq = Base::reduce_across_k(ksh_sq);
+        const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
+        const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);
+
+        if(tid == 0)
+            *sim_k_out = sim_k ? static_cast<uint8_t>(1) : static_cast<uint8_t>(0);
+    }
+};
+
+} // namespace ck_tile