sparse_attn: split KStats kernel, add README + perf charts

- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
  per-block K stats workspace consumed by Kernel B), removing redundant
  K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
  to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
  + reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Gino Lu
2026-05-05 03:13:24 -04:00
parent eca3cb3e0a
commit b00e5449c8
11 changed files with 839 additions and 96 deletions


@@ -0,0 +1,45 @@
# Sparge Attention (Composable Kernel)
A Composable Kernel port of [SpargeAttn](https://github.com/thu-ml/SpargeAttn) for AMD GPUs. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on the GPU. Two attention backends are exposed via `-pipeline=vsa` (default, faster) and `-pipeline=jenga` (async K/V load variant).
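The selection semantics are small enough to sketch on the host. Below is a minimal NumPy sketch of one (batch, head) block-map pass under the semantics above; the function name, defaults, and shapes are illustrative only (`block=128` mirrors the kernel's kN0), and the softmax/`cdfthreshd` path, per-head superparams, and all GPU tiling are omitted:
```python
import numpy as np

def block_map(q, k, block=128, topk=0.4, simthreshd1=1e-3, scale=1.0):
    """q: [s_q, d], k: [s_k, d] -> bool keep-map over [N_q, N_k] K-blocks."""
    def pool(x):
        n = (len(x) + block - 1) // block
        mean = np.stack([x[i * block:(i + 1) * block].mean(axis=0) for i in range(n)])
        unit = x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)
        # self-similarity gate: squared norm of the block's mean unit row
        sim = np.array([np.sum(unit[i * block:(i + 1) * block].mean(axis=0) ** 2)
                        for i in range(n)]) > simthreshd1
        return mean, sim

    q_mean, _ = pool(q)
    k_mean, sim_k = pool(k)
    scores = (q_mean @ k_mean.T) * scale          # pooled QK, [N_q, N_k]
    keep = np.zeros(scores.shape, dtype=bool)
    keep[:, ~sim_k] = True                        # dissimilar K-blocks forced ON...
    scores[:, ~sim_k] = -np.inf                   # ...and excluded from the top-k budget
    n_sel = max(1, int(topk * scores.shape[1]))
    top = np.argsort(-scores, axis=1)[:, :n_sel]  # per-Q-block winners -> LUT rows
    np.put_along_axis(keep, top, True, axis=1)
    return keep

rng = np.random.default_rng(0)
keep = block_map(rng.standard_normal((1024, 128)), rng.standard_normal((1024, 128)))
print(keep.shape, keep.mean())                    # (8, 8) and the kept fraction
```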
## Status vs Upstream
Implemented:
- per-block mean-pool, cosine similarity, pooled QK
- top-k / `cdfthreshd` block selection, BlockMap LUT
- sparse FMHA (both `vsa` and `jenga` backends)
- per-head `topk` / `simthreshd1` / `cdfthreshd`
Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):
- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5); see the sketch after this list ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
- **pv_threshold per-Q-tile skip in attn kernel** — pure perf, ~5-15% on the dominant attention slice ([spas_sage_attn/core.py:L265](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L265))
- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))
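For the first two unported items, the upstream logic being referenced is only a few lines; a hedged NumPy sketch of our reading of the pinned commit (function names are ours, not upstream's):
```python
import numpy as np

# K smoothing (core.py:L53): subtract the per-(batch, head) mean key before
# pooling, so pooled scores measure deviation from the mean key rather than
# raw magnitude. Here k is one head's keys, shape [s_k, d].
def smooth_k(k):
    return k - k.mean(axis=0, keepdims=True)

# is_causal in the pooled score (utils.py:L338): Q-block i may only select
# K-blocks whose first row is not after Q-block i's last row; disallowed
# entries get score -inf before the softmax / top-k selection.
def causal_pooled_mask(n_q, n_k, block_q, block_k):
    q_last = (np.arange(n_q) + 1) * block_q - 1
    k_first = np.arange(n_k) * block_k
    return k_first[None, :] <= q_last[:, None]    # [n_q, n_k] allowed map
```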
## Performance
At b=2 h=32 s=16384 fp16, sparge (vsa backend) reaches **1.78× FMHA throughput at topk=0.4** and **5.04× at topk=0.1**, and stays above 1.0× across the full topk range.
![Speedup vs sparsity](docs/speedup_vs_sparsity.png)
*Speedup vs FMHA, b=2 h=32 s=16384 d=128 fp16. Shape chosen to match Fig. 10 of the SpargeAttn paper ([arXiv:2502.18137](https://arxiv.org/abs/2502.18137); Mochi-1, 22K context, head_dim=128); s=16384 is the closest grid point. Gray-outlined points have >30% inter-rep spread.*
![Kernel breakdown](docs/kernel_breakdown.png)
*BlockMap (`_pre`) stacked on attention (`_attn`), b=2 h=32 d=128 fp16 topk=0.4. BlockMap is roughly 17% of total at s=16384.*
## Usage
```bash
ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
```
Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since the full-shape CPU reference scales as O(s²) and runs for 30+ minutes at s=8k and for hours at s=16k.
## References
- [SpargeAttn upstream](https://github.com/thu-ml/SpargeAttn) (pinned to [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a))
- [Paper — Zhang et al., arXiv:2502.18137](https://arxiv.org/abs/2502.18137)

Binary file not shown (PNG, 83 KiB).


@@ -0,0 +1,258 @@
#!/usr/bin/env python3
"""Plot sparge perf charts from full_grid.csv.

Re-run with different fixed (b, h, s, dtype, topk) by editing the constants below.
No GPU / no srun / no rebuild — pure matplotlib from CSV.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ----------------------------------------------------------------------
# Tunable constants — edit these to regenerate for a different point.
# ----------------------------------------------------------------------
CSV_PATH = "/home/AMD/ginolu12/gino_tmp/full_grid.csv"
OUT_DIR = os.path.dirname(os.path.abspath(__file__))

# Chart 1 — speedup vs topk for one fixed (b, h, s, dtype)
CHART1_B = 2
CHART1_H = 32
CHART1_S = 16384
CHART1_DTYPE = "fp16"
CHART1_HEAD_DIM = 128  # for title only

# Chart 2 — kernel breakdown across s for fixed (b, h, dtype, topk)
CHART2_B = 2
CHART2_H = 32
CHART2_DTYPE = "fp16"
CHART2_TOPK = 0.4
CHART2_S_LIST = [2048, 4096, 8192, 16384]
CHART2_HEAD_DIM = 128  # for title only

DPI = 140

# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def is_fail(note):
    if not isinstance(note, str):
        return False
    return "FAIL" in note


def is_high_spread(note):
    if not isinstance(note, str):
        return False
    return "HIGH_SPREAD" in note


def load_data():
    df = pd.read_csv(CSV_PATH)
    return df

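# Expected full_grid.csv columns, as consumed below (assumed benchmark-output
# layout): b, h, s, dtype, topk, note, max_spread_pct, fmha_us,
# sparge_{vsa,jenga}, sparse_{vsa,jenga}, and sparge_{vsa,jenga}_pre/_attn (µs).
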
# ----------------------------------------------------------------------
# Chart 1
# ----------------------------------------------------------------------
def plot_chart1(df, out_path):
    sel = df[
        (df["b"] == CHART1_B)
        & (df["h"] == CHART1_H)
        & (df["s"] == CHART1_S)
        & (df["dtype"] == CHART1_DTYPE)
    ].copy()
    sel = sel.sort_values("topk").reset_index(drop=True)
    if sel.empty:
        print(f"[chart1] WARNING: no rows for b={CHART1_B} h={CHART1_H} s={CHART1_S} dtype={CHART1_DTYPE}")
        return [], 0
    # Drop fully failed rows but keep partial-fail rows; we'll mask per-series.
    # Convert numeric columns.
    for col in ["sparge_jenga", "sparge_vsa", "sparse_jenga", "sparse_vsa", "fmha_us"]:
        sel[col] = pd.to_numeric(sel[col], errors="coerce")
    fmha = sel["fmha_us"]
    # Compute speedups; rows with FAIL on a given column will have NaN already.
    series = {
        "sparge_vsa": fmha / sel["sparge_vsa"],
        "sparge_jenga": fmha / sel["sparge_jenga"],
        "sparse_vsa": fmha / sel["sparse_vsa"],
        "sparse_jenga": fmha / sel["sparse_jenga"],
    }
    style = {
        "sparge_vsa": {"color": "#1f77b4", "marker": "o", "lw": 2.0},
        "sparge_jenga": {"color": "#ff7f0e", "marker": "s", "lw": 2.0},
        "sparse_vsa": {"color": "#2ca02c", "marker": "^", "lw": 1.5, "ls": "--"},
        "sparse_jenga": {"color": "#d62728", "marker": "v", "lw": 1.5, "ls": "--"},
    }
    fig, ax = plt.subplots(figsize=(8.5, 5.5), dpi=DPI)
    x = sel["topk"].to_numpy()
    # HIGH_SPREAD overlay first (under main markers).
    hs_mask = sel["note"].apply(is_high_spread)
    high_spread_cells = []
    if hs_mask.any():
        for _, row in sel[hs_mask].iterrows():
            high_spread_cells.append((row["topk"], row["max_spread_pct"]))
        # Gray ring underneath every series's data point at that x.
        for sp in series.values():
            xs_hs = x[hs_mask.to_numpy()]
            ys_hs = sp[hs_mask.to_numpy()].to_numpy()
            ax.scatter(xs_hs, ys_hs, s=180, facecolors="none",
                       edgecolors="gray", linewidths=1.5, zorder=2)
    for label, sp in series.items():
        st = style[label]
        ax.plot(x, sp.to_numpy(), label=label,
                color=st["color"], marker=st["marker"],
                linewidth=st["lw"], linestyle=st.get("ls", "-"),
                markersize=7, zorder=3)
    ax.axhline(1.0, color="black", linestyle=":", linewidth=1.2, label="fmha (baseline)", zorder=1)
    ax.set_xlabel("topk (kept fraction)")
    ax.set_ylabel("speedup vs FMHA dense (×)")
    ax.set_title(
        f"Speedup vs FMHA "
        f"(b={CHART1_B} h={CHART1_H} s={CHART1_S} d={CHART1_HEAD_DIM} {CHART1_DTYPE})"
    )
    ax.grid(True, which="both", linestyle=":", alpha=0.6)
    ax.set_xticks(np.arange(0.1, 0.71, 0.1))
    ax.legend(loc="best", framealpha=0.9)
    # Footnote about the HIGH_SPREAD overlay.
    if high_spread_cells:
        ax.text(0.01, -0.16,
                "Gray rings: HIGH_SPREAD cells (high run-to-run variance)",
                transform=ax.transAxes, fontsize=8, color="gray")
    fig.tight_layout()
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    return high_spread_cells, os.path.getsize(out_path)


# ----------------------------------------------------------------------
# Chart 2
# ----------------------------------------------------------------------
def plot_chart2(df, out_path):
    sel = df[
        (df["b"] == CHART2_B)
        & (df["h"] == CHART2_H)
        & (df["dtype"] == CHART2_DTYPE)
        & (np.isclose(df["topk"], CHART2_TOPK))
        & (df["s"].isin(CHART2_S_LIST))
    ].copy()
    sel = sel.sort_values("s").reset_index(drop=True)
    if sel.empty:
        print(f"[chart2] WARNING: no rows for b={CHART2_B} h={CHART2_H} dtype={CHART2_DTYPE} topk={CHART2_TOPK}")
        return 0
    for col in ["sparge_jenga_pre", "sparge_jenga_attn",
                "sparge_vsa_pre", "sparge_vsa_attn", "fmha_us"]:
        sel[col] = pd.to_numeric(sel[col], errors="coerce")
    s_vals = sel["s"].to_numpy()
    n = len(s_vals)
    idx = np.arange(n, dtype=float)
    width = 0.35
    offset = width / 2 + 0.02
    fig, ax = plt.subplots(figsize=(9.0, 5.8), dpi=DPI)
    # Jenga bars (left of group).
    jenga_pre = sel["sparge_jenga_pre"].to_numpy()
    jenga_attn = sel["sparge_jenga_attn"].to_numpy()
    vsa_pre = sel["sparge_vsa_pre"].to_numpy()
    vsa_attn = sel["sparge_vsa_attn"].to_numpy()
    fmha_vals = sel["fmha_us"].to_numpy()
    color_jenga_pre = "#fdbf6f"   # light orange
    color_jenga_attn = "#ff7f0e"  # orange
    color_vsa_pre = "#a6cee3"     # light blue
    color_vsa_attn = "#1f77b4"    # blue
    ax.bar(idx - offset, jenga_pre, width,
           color=color_jenga_pre, edgecolor="black", linewidth=0.6,
           label="sparge_jenga _pre (BlockMap)")
    ax.bar(idx - offset, jenga_attn, width, bottom=jenga_pre,
           color=color_jenga_attn, edgecolor="black", linewidth=0.6,
           label="sparge_jenga _attn")
    ax.bar(idx + offset, vsa_pre, width,
           color=color_vsa_pre, edgecolor="black", linewidth=0.6,
           label="sparge_vsa _pre (BlockMap)")
    ax.bar(idx + offset, vsa_attn, width, bottom=vsa_pre,
           color=color_vsa_attn, edgecolor="black", linewidth=0.6,
           label="sparge_vsa _attn")
    # Add total labels on top of each stack.
    totals_jenga = jenga_pre + jenga_attn
    totals_vsa = vsa_pre + vsa_attn
    for i in range(n):
        ax.text(idx[i] - offset, totals_jenga[i], f"{totals_jenga[i]:.0f}",
                ha="center", va="bottom", fontsize=8)
        ax.text(idx[i] + offset, totals_vsa[i], f"{totals_vsa[i]:.0f}",
                ha="center", va="bottom", fontsize=8)
    # FMHA reference: short horizontal dashed segment per group.
    seg_half = 0.40
    fmha_label_done = False
    for i in range(n):
        ax.hlines(fmha_vals[i], idx[i] - seg_half, idx[i] + seg_half,
                  colors="black", linestyles="dashed", linewidth=1.2,
                  label="fmha dense (reference)" if not fmha_label_done else None,
                  zorder=5)
        ax.text(idx[i] + seg_half + 0.02, fmha_vals[i],
                f"fmha {fmha_vals[i]:.0f}", fontsize=7, va="center", color="black")
        fmha_label_done = True
    ax.set_xticks(idx)
    ax.set_xticklabels([f"s={s}" for s in s_vals.astype(int)])
    ax.set_xlabel("sequence length (s)")
    ax.set_ylabel("kernel time (µs)")
    ax.set_title(
        f"Sparge kernel time breakdown "
        f"(b={CHART2_B} h={CHART2_H} d={CHART2_HEAD_DIM} {CHART2_DTYPE}, topk={CHART2_TOPK})"
    )
    ax.grid(True, axis="y", linestyle=":", alpha=0.6)
    ax.legend(loc="upper left", framealpha=0.9, fontsize=9)
    # log-y is too aggressive — leave linear; bars will just be tall.
    fig.tight_layout()
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    return os.path.getsize(out_path)


# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    df = load_data()
    chart1_path = os.path.join(OUT_DIR, "speedup_vs_sparsity.png")
    chart2_path = os.path.join(OUT_DIR, "kernel_breakdown.png")
    hs_cells, size1 = plot_chart1(df, chart1_path)
    size2 = plot_chart2(df, chart2_path)
    print(f"Wrote {chart1_path} ({size1} bytes)")
    print(f"Wrote {chart2_path} ({size2} bytes)")
    if hs_cells:
        print("HIGH_SPREAD cells in chart-1 selection:")
        for topk, pct in hs_cells:
            print(f"  topk={topk} max_spread_pct={pct}")
    else:
        print("No HIGH_SPREAD cells in chart-1 selection.")


if __name__ == "__main__":
    main()

Binary file not shown (PNG, 124 KiB).


@@ -5,6 +5,9 @@
#include "sparge_blockmap_trek.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>
#include <iostream>
// ============================================================================
@@ -61,6 +64,9 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, //
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;
using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_fp16_problem>;
using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel<kstats_fp16_pipeline>;
// ============================================================================
// bf16: D=128, kM0=64, kN0=128
// ============================================================================
@@ -112,6 +118,78 @@ using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, //
using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_bf16_problem>;
using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel<bmap_bf16_pipeline>;
using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_bf16_problem>;
using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel<kstats_bf16_pipeline>;
// ============================================================================
// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized
// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged.
// ============================================================================
namespace {

struct KStatsWorkspace
{
    void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32
    void* sim_k_dev    = nullptr; // [batch, nhead_k, N_k] uint8
    size_t pooled_k_bytes = 0;
    size_t sim_k_bytes    = 0;

    void ensure(int batch, int nhead_k, int N_k, int D)
    {
        const size_t need_p = static_cast<size_t>(batch) * nhead_k * N_k * D * sizeof(float);
        const size_t need_s = static_cast<size_t>(batch) * nhead_k * N_k * sizeof(uint8_t);
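        // Worked example, assuming kN0=128: b=2, nhead_k=32, s=16384 gives
        // N_k=128; with D=128, need_p = 2*32*128*128*4 B = 4 MiB, need_s = 8 KiB.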
        if(need_p > pooled_k_bytes)
        {
            if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev);
            (void)hipMalloc(&pooled_k_dev, need_p);
            pooled_k_bytes = need_p;
        }
        if(need_s > sim_k_bytes)
        {
            if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev);
            (void)hipMalloc(&sim_k_dev, need_s);
            sim_k_bytes = need_s;
        }
    }
};

KStatsWorkspace& g_kstats_ws()
{
    static KStatsWorkspace ws;
    return ws;
}

template <typename KStatsKernel, typename BlockMapKernel>
void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s)
{
    const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0);
    const int D   = BlockMapKernel::D;
    auto& ws      = g_kstats_ws();
    ws.ensure(args.batch, args.nhead_k, N_k, D);
    // Stage 1: K stats
    {
        auto [kargs, grids] =
            sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, ws.pooled_k_dev, ws.sim_k_dev);
        const dim3 blocks = KStatsKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
    }
    // Stage 2: block_map (reads ws)
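    // Both launches target the same HIP stream, so Kernel B's reads of the
    // workspace are ordered after Kernel A's writes without an explicit sync.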
    {
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(
            args, ws.pooled_k_dev, ws.sim_k_dev);
        const dim3 blocks = BlockMapKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
    }
}

} // namespace
// ============================================================================
// Dispatch
// ============================================================================
@@ -122,26 +200,20 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits,
{
    if(traits.data_type == "fp16" && traits.hdim_q == 128)
    {
        using k_ = bmap_fp16_kernel;
        if(s.log_level_ > 0)
            std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        return ck_tile::launch_kernel(
            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
            launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s_);
        });
    }
    if(traits.data_type == "bf16" && traits.hdim_q == 128)
    {
        using k_ = bmap_bf16_kernel;
        if(s.log_level_ > 0)
            std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        return ck_tile::launch_kernel(
            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
            launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s_);
        });
    }
    if(s.log_level_ > 0)
@@ -160,23 +232,13 @@ void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
{
    if(traits.data_type == "fp16" && traits.hdim_q == 128)
    {
        using k_ = bmap_fp16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s);
        return;
    }
    if(traits.data_type == "bf16" && traits.hdim_q == 128)
    {
        using k_ = bmap_bf16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s);
        return;
    }


@@ -8,7 +8,9 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp"
#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
#include "ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp"
#include "fmha_fwd_trek.hpp"
@@ -45,6 +47,15 @@ struct sparge_blockmap_args
    void* block_map_ptr;
    void* lut_ptr;
    void* valid_block_num_ptr;
    // R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar.
    // Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized
    // by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the
    // first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for
    // MQA/GQA aligns to upstream tuned ckpt layout.
    const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1])
    const float* cdfthreshd_per_head_ptr  = nullptr; // size = nhead_q floats
    const float* topk_per_head_ptr        = nullptr; // size = nhead_q floats
};
struct sparge_blockmap_traits
@@ -57,7 +68,9 @@ struct sparge_blockmap_traits
// Create kernel args and grid dimensions
// ============================================================================
template <typename BlockMapKernel>
auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args,
                                            const void* pooled_k_ws_ptr,
                                            const void* sim_k_ws_ptr)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = BlockMapKernel::MakeKargs(args.q_ptr,
@@ -79,12 +92,38 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
                                           args.scale,
                                           args.block_map_ptr,
                                           args.lut_ptr,
                                           args.valid_block_num_ptr);
                                           args.valid_block_num_ptr,
                                           pooled_k_ws_ptr,
                                           sim_k_ws_ptr,
                                           args.topk_per_head_ptr,
                                           args.cdfthreshd_per_head_ptr);
    dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}

template <typename KStatsKernel>
auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args,
                                          void* pooled_k_ws_ptr,
                                          void* sim_k_ws_ptr)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = KStatsKernel::MakeKargs(args.k_ptr,
                                         args.seqlen_k,
                                         args.hdim_q,
                                         args.nhead_k,
                                         args.stride_k,
                                         args.nhead_stride_k,
                                         args.batch_stride_k,
                                         args.simthreshd1,
                                         pooled_k_ws_ptr,
                                         sim_k_ws_ptr,
                                         args.simthreshd1_per_head_ptr);
    dim3 grids = KStatsKernel::GridSize(args.batch, args.nhead_k, args.seqlen_k);
    return ck_tile::make_tuple(kargs, grids);
}
// ============================================================================
// Hand-written template instantiation dispatch
// ============================================================================


@@ -105,7 +105,10 @@ auto create_args(int argc, char* argv[])
.insert("seed", "42", "random seed")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations")
.insert("kname", "0", "print kernel name");
.insert("kname", "0", "print kernel name")
.insert("perhead", "0",
"R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
"(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -135,6 +138,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
    int warmup  = arg_parser.get_int("warmup");
    int repeat  = arg_parser.get_int("repeat");
    int kname   = arg_parser.get_int("kname");
    int perhead = arg_parser.get_int("perhead");

    if(nhead_k < 0) nhead_k = nhead;
    if(seqlen_k < 0) seqlen_k = seqlen_q;
@@ -231,6 +235,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
    bmap_args.lut_ptr             = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
    bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;

    // R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
    // to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
    // K-side kernel reads only the first nhead_k entries via [hk].
    ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem cdf_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    if(perhead != 0)
    {
        std::vector<float> topk_h(nhead);
        std::vector<float> sim1_h(nhead);
        std::vector<float> cdf_h(nhead);
        for(int h = 0; h < nhead; ++h)
        {
            // small per-head jitter around scalar topk so sparsity differs by head
            const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
            topk_h[h] = topk * (1.0f + jitter);
            sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
            cdf_h[h]  = cdfthreshd;
        }
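        // e.g. nhead=32, topk=0.4: head 0 gets topk 0.4*(1-0.25) = 0.30,
        // head 31 gets 0.4*(1+0.5*15/32) ~= 0.49.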
        topk_per_head_dev.ToDevice(topk_h.data());
        sim1_per_head_dev.ToDevice(sim1_h.data());
        cdf_per_head_dev.ToDevice(cdf_h.data());
        bmap_args.topk_per_head_ptr        = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
        bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
        bmap_args.cdfthreshd_per_head_ptr  = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
    }

    // ---- build attention args ----
    ck_tile::stream_config stream_cfg;
    stream_cfg.stream_id_ = nullptr;


@@ -52,7 +52,20 @@ struct SpargeBlockMapKernel
        void* lut_ptr;
        void* valid_block_num_ptr;
        // R20 K-stat workspace from Kernel A
        const void* pooled_k_ws_ptr; // [batch, nhead_k, N_k, D] fp32
        const void* sim_k_ws_ptr;    // [batch, nhead_k, N_k] uint8
        index_t N_k;
        // R21A Phase 4: optional per-head topk (size = nhead_q floats).
        // nullptr => use scalar `topk` for all heads.
        const float* topk_per_head;
        // R21B: optional per-head cdfthreshd (size = nhead_q floats).
        // nullptr => use scalar `cdfthreshd` for all heads.
        // Only consulted on topk<=0 path; bench currently always uses topk path.
        const float* cdfthreshd_per_head;
    };

    CK_TILE_HOST static constexpr auto MakeKargs(const void* q_ptr,
@@ -74,7 +87,11 @@ struct SpargeBlockMapKernel
                                                 float scale,
                                                 void* block_map_ptr,
                                                 void* lut_ptr,
                                                 void* valid_block_num_ptr)
                                                 void* valid_block_num_ptr,
                                                 const void* pooled_k_ws_ptr,
                                                 const void* sim_k_ws_ptr,
                                                 const float* topk_per_head       = nullptr,
                                                 const float* cdfthreshd_per_head = nullptr)
    {
        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
        return Kargs{q_ptr,
@@ -97,7 +114,11 @@ struct SpargeBlockMapKernel
                     block_map_ptr,
                     lut_ptr,
                     valid_block_num_ptr,
                     N_k};
                     pooled_k_ws_ptr,
                     sim_k_ws_ptr,
                     N_k,
                     topk_per_head,
                     cdfthreshd_per_head};
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_q, index_t seqlen_q)
@@ -174,6 +195,21 @@ struct SpargeBlockMapKernel
        // Shared memory
        __shared__ char smem[Pipeline::GetSmemSize()];

        // R20 K-stat workspace: pre-offset for this (b, hk).
        const index_t nhead_k   = kargs.nhead_q / kargs.nhead_ratio_qk;
        const index_t khead_off = (b * nhead_k + hk) * N_k;
        const auto* pooled_k_ws =
            reinterpret_cast<const float*>(kargs.pooled_k_ws_ptr) + khead_off * D;
        const auto* sim_k_ws =
            reinterpret_cast<const uint8_t*>(kargs.sim_k_ws_ptr) + khead_off;
        // R21A Phase 4: per-head topk if provided, else scalar broadcast.
        const float topk_eff =
            (kargs.topk_per_head != nullptr) ? kargs.topk_per_head[hq] : kargs.topk;
        // R21B: per-head cdfthreshd if provided, else scalar broadcast.
        const float cdfthreshd_eff =
            (kargs.cdfthreshd_per_head != nullptr) ? kargs.cdfthreshd_per_head[hq] : kargs.cdfthreshd;

        Pipeline{}(q_window,
                   k_window,
                   kargs.seqlen_q,
@@ -182,12 +218,14 @@ struct SpargeBlockMapKernel
                   N_k,
                   kargs.nhead_ratio_qk,
                   kargs.simthreshd1,
                   kargs.cdfthreshd,
                   kargs.topk,
                   cdfthreshd_eff,
                   topk_eff,
                   kargs.scale,
                   bmap_ptr,
                   lut_out,
                   valid_out,
                   pooled_k_ws,
                   sim_k_ws,
                   static_cast<void*>(smem));
    }
};


@@ -0,0 +1,136 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include <type_traits>

namespace ck_tile {

// Kernel A wrapper: grid (N_k, nhead_k, batch). Each work-group precomputes
// K-block stats (pooled_k_mean[D], sim_k) for one (b, hk, kb) into a workspace
// that Kernel B (block_map) reads instead of recomputing per Q-block.
template <typename Pipeline_>
struct SpargeKStatsKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;

    static constexpr index_t kBlockSize  = Pipeline::kBlockSize;
    static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;

    using QDataType = typename Pipeline::QDataType;
    using KDataType = typename Pipeline::KDataType;

    static constexpr index_t kN0        = Pipeline::kN0;
    static constexpr index_t D          = Pipeline::D;
    static constexpr index_t kAlignment = 16 / sizeof(KDataType);

    struct Kargs
    {
        const void* k_ptr;
        index_t seqlen_k;
        index_t hdim_q;
        index_t nhead_k;
        index_t stride_k;
        index_t nhead_stride_k;
        index_t batch_stride_k;
        float simthreshd1;
        void* pooled_k_ptr; // [batch, nhead_k, N_k, D] fp32
        void* sim_k_ptr;    // [batch, nhead_k, N_k] uint8
        index_t N_k;
        // R21A Phase 4 + R21B fix: optional per-head simthreshd1.
        // Buffer is sized [nhead_q] floats to match SpargeAttn upstream contract
        // (utils.py:324, Headnum=q.size(1)). Kernel only indexes the first
        // nhead_k entries via [hk]. nullptr => use scalar `simthreshd1`.
        const float* simthreshd1_per_head;
    };

    CK_TILE_HOST static constexpr auto MakeKargs(const void* k_ptr,
                                                 index_t seqlen_k,
                                                 index_t hdim_q,
                                                 index_t nhead_k,
                                                 index_t stride_k,
                                                 index_t nhead_stride_k,
                                                 index_t batch_stride_k,
                                                 float simthreshd1,
                                                 void* pooled_k_ptr,
                                                 void* sim_k_ptr,
                                                 const float* simthreshd1_per_head = nullptr)
    {
        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
        return Kargs{k_ptr,
                     seqlen_k,
                     hdim_q,
                     nhead_k,
                     stride_k,
                     nhead_stride_k,
                     batch_stride_k,
                     simthreshd1,
                     pooled_k_ptr,
                     sim_k_ptr,
                     N_k,
                     simthreshd1_per_head};
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t batch, index_t nhead_k, index_t seqlen_k)
    {
        const index_t N_k = integer_divide_ceil(seqlen_k, kN0);
        return dim3(N_k, nhead_k, batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const index_t kb = static_cast<index_t>(blockIdx.x);
        const index_t hk = static_cast<index_t>(blockIdx.y);
        const index_t b  = static_cast<index_t>(blockIdx.z);

        const auto* k_base = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
                             b * kargs.batch_stride_k + hk * kargs.nhead_stride_k +
                             kb * kN0 * kargs.stride_k;
        const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
            k_base,
            make_tuple(kargs.seqlen_k - kb * kN0, D),
            make_tuple(kargs.stride_k, 1),
            number<kAlignment>{},
            number<1>{});
        const auto k_dram = pad_tensor_view(
            k_dram_naive, make_tuple(number<kN0>{}, number<D>{}), sequence<true, false>{});
        auto k_window = make_tile_window(k_dram,
                                         make_tuple(number<kN0>{}, number<D>{}),
                                         {0, 0},
                                         Pipeline::MakeKBlockDistribution());

        const index_t N_k       = kargs.N_k;
        const index_t khead_off = (b * kargs.nhead_k + hk) * N_k;
        auto* pooled_k_out = reinterpret_cast<float*>(kargs.pooled_k_ptr) + (khead_off + kb) * D;
        auto* sim_k_out    = reinterpret_cast<uint8_t*>(kargs.sim_k_ptr) + (khead_off + kb);

        __shared__ char smem[Pipeline::GetSmemSize()];

        // R21A Phase 4: per-head simthreshd1 if provided, else scalar broadcast.
        const float simthreshd1_eff = (kargs.simthreshd1_per_head != nullptr)
                                          ? kargs.simthreshd1_per_head[hk]
                                          : kargs.simthreshd1;

        Pipeline{}(k_window,
                   kargs.seqlen_k,
                   kb,
                   simthreshd1_eff,
                   pooled_k_out,
                   sim_k_out,
                   static_cast<void*>(smem));
    }
};

} // namespace ck_tile


@@ -32,14 +32,22 @@ struct SpargeBlockMapPipeline
    static constexpr index_t kMaxKBlocks = 1024;

    // LDS layout (non-overlapping, all used simultaneously in Phase 2):
    // [0 .. kReduceBytes)   cross-warp reduction scratch
    // [kScoreOffset ..)     scores[N_k]
    // [kBmapOffset ..)      block_map[N_k]
    // [kSmallOffset ..)     Phase 3 argmax scratch (2*NumWarps floats)
    static constexpr index_t kReduceBytes = NumWarps * D * sizeof(float);
    static constexpr index_t kScoreOffset = kReduceBytes;
    static constexpr index_t kBmapOffset  = kScoreOffset + kMaxKBlocks * sizeof(float);
    static constexpr index_t kSmallOffset = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);
    // [0 .. kReduceBytes)              cross-warp reduction scratch slab 0
    // [kReduceBytes .. 2*kReduceBytes) cross-warp reduction scratch slab 1
    //                                  (Round 8 b1: ping-pong for K-loop double buffer)
    // [kScoreOffset ..)                scores[N_k]
    // [kBmapOffset ..)                 block_map[N_k]
    // [kSmallOffset ..)                Phase 3 argmax scratch (2*NumWarps floats)
    // B2.v3 column-stride pad: replace k_idx*KPerThread with k_idx*(KPerThread+1)
    // to break the 4-way intra-warp bank conflict. New per-warp slab size:
    // KThreads * (KPerThread + 1) floats.
    static constexpr index_t kColPaddedStride  = KPerThread + 1;
    static constexpr index_t kPerWarpFloats    = KThreads * kColPaddedStride;
    static constexpr index_t kReduceBytes      = NumWarps * kPerWarpFloats * sizeof(float);
    static constexpr index_t kReduceTotalBytes = 2 * kReduceBytes; // Round 8 b1: 2 slabs
    static constexpr index_t kScoreOffset      = kReduceTotalBytes;
    static constexpr index_t kBmapOffset       = kScoreOffset + kMaxKBlocks * sizeof(float);
    static constexpr index_t kSmallOffset      = kBmapOffset + kMaxKBlocks * sizeof(uint8_t);

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
@@ -98,6 +106,12 @@ struct SpargeBlockMapPipeline
    }

    // Cross-warp LDS reduction for column sums.
    // Round 13f: templated TrailingSync flag. When false, the trailing __syncthreads()
    // is dropped — only safe when the next access targets a *different* slab and the
    // intervening work does not read smem_reduce. Used at the slab_b call in Phase 2
    // K-loop, where the next iter's first cross-warp reduce writes to slab_a (different
    // address) and is preceded by its own leading sync.
    template <bool TrailingSync = true>
    CK_TILE_DEVICE static void column_reduce_cross_warp(float (&col_acc)[KPerThread],
                                                        float* __restrict__ smem_reduce)
    {
@@ -107,17 +121,21 @@ struct SpargeBlockMapPipeline
        const index_t k_idx = lane_id % KThreads;
        const index_t m_idx = lane_id / KThreads;

        // B2.v3 column-stride pad: stride k_idx by (KPerThread+1)=9 instead of 8,
        // changing per-lane bank from (k_idx*8+k)%32 to (k_idx*9+k)%32. For k=0,
        // lanes (k_idx={0,4,8,12}) now hit banks {0,4,8,12} instead of all 0.
        if(m_idx == 0)
            for(index_t k = 0; k < KPerThread; ++k)
                smem_reduce[warp_id * D + k_idx * KPerThread + k] = col_acc[k];
                smem_reduce[warp_id * kPerWarpFloats + k_idx * kColPaddedStride + k] = col_acc[k];
        __syncthreads();
        for(index_t k = 0; k < KPerThread; ++k)
            col_acc[k] = 0.f;
        for(index_t w = 0; w < NumWarps; ++w)
            for(index_t k = 0; k < KPerThread; ++k)
                col_acc[k] += smem_reduce[w * D + k_idx * KPerThread + k];
        __syncthreads();
                col_acc[k] += smem_reduce[w * kPerWarpFloats + k_idx * kColPaddedStride + k];
        if constexpr(TrailingSync)
            __syncthreads();
    }

    // Compute ||v||^2 per row: sum along KPerThread then xor-shuffle across k_idx.
@@ -162,7 +180,8 @@ struct SpargeBlockMapPipeline
        for(index_t m = 0; m < SeqPerThread; ++m)
        {
            float inv_norm = (row_norms[m] > 0.f) ? (1.0f / __builtin_sqrtf(row_norms[m])) : 0.f;
            // Round 12: hardware fast rsqrt (v_rsq_f32, ~1 ULP) replaces sw sqrt+rcp.
            float inv_norm = (row_norms[m] > 0.f) ? rsqrtf(row_norms[m]) : 0.f;
            index_t gsq = m * (SeqThreadPerWarp * NumWarps) + warp_id * SeqThreadPerWarp + m_idx;
            if(gsq < actual_seq)
                for(index_t k = 0; k < KPerThread; ++k)
@@ -230,9 +249,9 @@ struct SpargeBlockMapPipeline
    // ======================================================================
    template <typename QWindowType, typename KWindowType>
    CK_TILE_DEVICE void operator()(const QWindowType& q_window_in,
                                   const KWindowType& k_window_in,
                                   const KWindowType& /*k_window_in*/,
                                   index_t seqlen_q,
                                   index_t seqlen_k,
                                   index_t /*seqlen_k*/,
                                   index_t qb,
                                   index_t N_k,
                                   index_t /*nhead_ratio_qk*/,
@@ -243,11 +262,15 @@ struct SpargeBlockMapPipeline
                                   uint8_t* block_map_ptr,
                                   int32_t* lut_ptr,
                                   int32_t* valid_block_num_ptr,
                                   const float* __restrict__ pooled_k_ws_ptr,
                                   const uint8_t* __restrict__ sim_k_ws_ptr,
                                   void* smem_ptr) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        auto* smem_float = reinterpret_cast<float*>(smem_ptr);
        // R20: K-loop no longer reduces, only Phase 1 uses smem_float0.
        // smem_float1 slab is allocated for layout compat but unused.
        auto* smem_float0 = reinterpret_cast<float*>(smem_ptr);
        auto* smem_scores =
            reinterpret_cast<float*>(reinterpret_cast<char*>(smem_ptr) + kScoreOffset);
        auto* smem_bmap =
@@ -271,16 +294,22 @@ struct SpargeBlockMapPipeline
        row_reduce_sq_norm<MPerThread>(q_data, psq, bs_q);

        // 1b. Column sum -> mean
        // Track F (re-apply R8 b2): drop trailing sync. Next reduce reuses same slab
        // (smem_float0) and has its own leading __syncthreads() before reading.
        // pooled_q_mean is register-only between reduces.
        float pooled_q_mean[KPerThread];
        column_reduce_thread_and_warp<MPerThread>(q_data, pooled_q_mean);
        column_reduce_cross_warp(pooled_q_mean, smem_float);
        column_reduce_cross_warp<false>(pooled_q_mean, smem_float0);
        for(index_t k = 0; k < KPerThread; ++k)
            pooled_q_mean[k] *= inv_bs_q;

        // 1c. Normalised sum_hat
        // Track F (re-apply R8 b2): drop trailing sync. Next cross-warp reduce in
        // K-loop iter 0 writes slab_a=smem_float0 (kb=0 even). Although same slab,
        // its leading __syncthreads() covers the WAR. sum_hat register-only here.
        float sum_hat[KPerThread];
        column_reduce_normalised<MPerThread>(q_data, psq, sum_hat, bs_q);
        column_reduce_cross_warp(sum_hat, smem_float);
        column_reduce_cross_warp<false>(sum_hat, smem_float0);

        // 1d. sim_q = ||sum_hat||^2 / bs_q^2
        float sh_sq = 0.f;
@@ -319,49 +348,34 @@ struct SpargeBlockMapPipeline
            smem_bmap[i] = 0;
        __syncthreads();

        auto k_window = k_window_in;
        // R20: K-stats precomputed by Kernel A. Each thread loads its own
        // KPerThread-slice of pooled_k_mean from DRAM workspace; sim_k is a single
        // byte. No K-tile load, no cross-warp reduce in the K-loop.
        const index_t lane_id_kb = tid % WarpSize;
        const index_t k_idx_kb   = lane_id_kb % KThreads;
        for(index_t kb = 0; kb < N_k; ++kb)
        {
            const index_t bs_k   = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
            const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;
            auto k_tile = load_tile(k_window);
            float k_data[NPerThread * KPerThread];
            tile_to_float<NPerThread * KPerThread>(k_tile, k_data);

            // K mean
            const float* p_kb = pooled_k_ws_ptr + kb * D + k_idx_kb * KPerThread;
            float pooled_k_mean[KPerThread];
            column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
            column_reduce_cross_warp(pooled_k_mean, smem_float);
            for(index_t k = 0; k < KPerThread; ++k)
                pooled_k_mean[k] *= inv_bs_k;
                pooled_k_mean[k] = p_kb[k];

            // dot(pooled_q_mean, pooled_k_mean)
            float dot = 0.f;
            for(index_t k = 0; k < KPerThread; ++k)
                dot += pooled_q_mean[k] * pooled_k_mean[k];
            dot = reduce_across_k(dot);

            // K L2 norms + normalised sum_hat
            float k_psq[NPerThread];
            row_reduce_sq_norm<NPerThread>(k_data, k_psq, bs_k);
            float k_sum_hat[KPerThread];
            column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
            column_reduce_cross_warp(k_sum_hat, smem_float);

            // sim_k
            float ksh_sq = 0.f;
            for(index_t k = 0; k < KPerThread; ++k)
                ksh_sq += k_sum_hat[k] * k_sum_hat[k];
            ksh_sq = reduce_across_k(ksh_sq);
            const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
            const bool sim_k = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);
            const bool sim_k = (sim_k_ws_ptr[kb] != 0);

            if(tid == 0)
            {
                // INVARIANT (mirrors SpargeAttn ref utils.py:175-180):
                // ~sim_k blocks are forced ON in the bitmap (final_map[~sim_k]=1)
                // AND have score = -inf so Phase 3 selection (topk / cdf) does NOT
                // pick them again (would double-count toward topk budget).
                // Both writes MUST stay together. Any Phase 3 selection rewrite
                // (e.g. iterative argmax → bitonic sort) must keep the -inf write.
                if(!sim_k)
                {
                    smem_bmap[kb] = 1;
@@ -372,10 +386,8 @@ struct SpargeBlockMapPipeline
                    smem_scores[kb] = dot * scale;
                }
            }
            __syncthreads();
            move_tile_window(k_window, {kN0, 0});
        }
        __syncthreads(); // guard Phase 3's reads of smem_bmap / smem_scores

        // ==================================================================
        // Phase 3: Softmax + Selection
@@ -399,15 +411,24 @@ struct SpargeBlockMapPipeline
        }
        const float sum_exp = block_reduce_sum(lsum, smem_small);

        // normalise
        const float inv_sum = (sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
        for(index_t i = tid; i < N_k; i += kBlockSize)
            smem_scores[i] *= inv_sum;
        __syncthreads();
        // Round 13i: argmax is invariant under positive scaling (inv_sum > 0). When
        // topk > 0 we never read normalised values for cdfthreshd, so skip the
        // normalise pass entirely (saves N_k LDS writes + 1 __syncthreads). The
        // cdfthreshd path (topk <= 0) still requires normalised scores so the
        // accumulator `cumulative_prob` matches probabilities.
        const bool topk_active = (topk > 0.f);
        const float inv_sum =
            (!topk_active && sum_exp > 0.f) ? (1.0f / sum_exp) : 0.f;
        if(!topk_active)
        {
            for(index_t i = tid; i < N_k; i += kBlockSize)
                smem_scores[i] *= inv_sum;
            __syncthreads();
        }

        // Selection: iterative argmax
        index_t num_to_select =
            (topk > 0.f)
            topk_active
                ? max(static_cast<index_t>(1), static_cast<index_t>(topk * static_cast<float>(N_k)))
                : N_k;
@@ -448,6 +469,11 @@ struct SpargeBlockMapPipeline
            }
            __syncthreads();

            // Round 13g: collapse 2 syncs/round into 1. tid==0 computes the global
            // winner AND writes the sentinel (smem_bmap=1, smem_scores=-1) in the same
            // critical section, gated by bv>0. All threads then read smem_small[0] for
            // the early break / cumulative_prob accumulation. Saves 1 __syncthreads per
            // round (~32 syncs @ N_k=64 topk=0.5).
            if(tid == 0)
            {
                float bv = smem_small[0];
@@ -462,24 +488,22 @@ struct SpargeBlockMapPipeline
                        bi = wi;
                    }
                }
                // Write sentinel into bmap/scores in the same critical section.
                // Guarded by bv > 0 so we never poison a valid score with -1.
                if(bv > 0.f)
                {
                    smem_bmap[bi] = 1;
                    smem_scores[bi] = -1.f;
                }
                smem_small[0] = bv;
                smem_small[1] = bit_cast<float>(static_cast<int32_t>(bi));
            }
            __syncthreads();
            float g_val = smem_small[0];
            index_t g_idx = bit_cast<int32_t>(smem_small[1]);
            float g_val   = smem_small[0];
            if(g_val <= 0.f)
                break;
            if(tid == 0)
            {
                smem_bmap[g_idx] = 1;
                smem_scores[g_idx] = -1.f;
            }
            __syncthreads();
            if(topk > 0.f)
            {
                if(round + 1 >= num_to_select)

@@ -0,0 +1,110 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"

namespace ck_tile {

// Kernel A of the K-stat precompute split: one work-group per (b, hk, kb)
// computes pooled_k_mean and sim_k for that K-block once. Kernel B then reads
// from the workspace instead of recomputing per Q-block.
template <typename Problem_>
struct SpargeKStatsPipeline
{
    using Problem = remove_cvref_t<Problem_>;
    using Base    = SpargeBlockMapPipeline<Problem>;

    using QDataType = typename Base::QDataType;
    using KDataType = typename Base::KDataType;

    static constexpr index_t kBlockSize       = Base::kBlockSize;
    static constexpr index_t kM0              = Base::kM0;
    static constexpr index_t kN0              = Base::kN0;
    static constexpr index_t D                = Base::D;
    static constexpr index_t NumWarps         = Base::NumWarps;
    static constexpr index_t WarpSize         = Base::WarpSize;
    static constexpr index_t KPerThread       = Base::KPerThread;
    static constexpr index_t KThreads         = Base::KThreads;
    static constexpr index_t SeqThreadPerWarp = Base::SeqThreadPerWarp;
    static constexpr index_t NPerThread       = Base::NPerThread;
    static constexpr index_t kBlockPerCu      = 1;

    static constexpr index_t kColPaddedStride = Base::kColPaddedStride;
    static constexpr index_t kPerWarpFloats   = Base::kPerWarpFloats;
    static constexpr index_t kReduceBytes     = NumWarps * kPerWarpFloats * sizeof(float);

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return kReduceBytes; }

    CK_TILE_HOST_DEVICE static constexpr auto MakeKBlockDistribution()
    {
        return Base::MakeKBlockDistribution();
    }

    // operator(): one work-group, one K-block. Writes D fp32 + 1 uint8 to workspace.
    template <typename KWindowType>
    CK_TILE_DEVICE void operator()(const KWindowType& k_window,
                                   index_t seqlen_k,
                                   index_t kb,
                                   float simthreshd1,
                                   float* __restrict__ pooled_k_out, // D floats
                                   uint8_t* __restrict__ sim_k_out,  // 1 byte
                                   void* smem_ptr) const
    {
        const index_t tid = static_cast<index_t>(threadIdx.x);
        auto* smem_reduce = reinterpret_cast<float*>(smem_ptr);

        const index_t bs_k   = min(static_cast<index_t>(kN0), seqlen_k - kb * kN0);
        const float inv_bs_k = (bs_k > 0) ? (1.0f / static_cast<float>(bs_k)) : 0.f;

        auto k_tile = load_tile(k_window);
        float k_data[NPerThread * KPerThread];
        Base::template tile_to_float<NPerThread * KPerThread>(k_tile, k_data);

        const index_t warp_id = tid / WarpSize;
        const index_t lane_id = tid % WarpSize;
        const index_t k_idx   = lane_id % KThreads;
        const index_t m_idx   = lane_id / KThreads;

        // pooled_k_mean: column sum then cross-warp reduce.
        // R21A: drop trailing sync (next cross_warp_reduce has its own leading sync).
        float pooled_k_mean[KPerThread];
        Base::template column_reduce_thread_and_warp<NPerThread>(k_data, pooled_k_mean);
        Base::template column_reduce_cross_warp<false>(pooled_k_mean, smem_reduce);
        for(index_t k = 0; k < KPerThread; ++k)
            pooled_k_mean[k] *= inv_bs_k;

        // R21A: write pooled_k_mean to global early so its register liveness ends here,
        // freeing VGPR before k_sum_hat becomes live.
        if(warp_id == 0 && m_idx == 0)
        {
            for(index_t k = 0; k < KPerThread; ++k)
                pooled_k_out[k_idx * KPerThread + k] = pooled_k_mean[k];
        }

        // K row L2 norms + normalised column sum (k_sum_hat)
        float k_psq[NPerThread];
        Base::template row_reduce_sq_norm<NPerThread>(k_data, k_psq, bs_k);
        float k_sum_hat[KPerThread];
        Base::template column_reduce_normalised<NPerThread>(k_data, k_psq, k_sum_hat, bs_k);
        // R21A: drop trailing sync (no further smem read; only intra-warp shuffle + global write).
        Base::template column_reduce_cross_warp<false>(k_sum_hat, smem_reduce);

        // sim_k = (||k_sum_hat||^2 / bs_k^2) > simthreshd1
        float ksh_sq = 0.f;
        for(index_t k = 0; k < KPerThread; ++k)
            ksh_sq += k_sum_hat[k] * k_sum_hat[k];
        ksh_sq = Base::reduce_across_k(ksh_sq);
        const float denom_k = static_cast<float>(bs_k) * static_cast<float>(bs_k);
        const bool sim_k    = (denom_k > 0.f) && ((ksh_sq / denom_k) > simthreshd1);
        if(tid == 0)
            *sim_k_out = sim_k ? static_cast<uint8_t>(1) : static_cast<uint8_t>(0);
    }
};

} // namespace ck_tile