mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
sparse_attn: split KStats kernel, add README + perf charts
- Split SpargeKStatsKernel/Pipeline out of BlockMap (Kernel A produces
per-block K stats workspace consumed by Kernel B), removing redundant
K-stat recomputation across Q-blocks.
- Add example/ck_tile/50_sparse_attn/README.md (status vs upstream pinned
to ae5b629, unported items, usage, references).
- Add example/ck_tile/50_sparse_attn/docs/{speedup_vs_sparsity,kernel_breakdown}.png
+ reusable plot_sparge_perf.py (b=2 h=32 s=16384 d=128 fp16 perf snapshot).
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
45
example/ck_tile/50_sparse_attn/README.md
Normal file
@@ -0,0 +1,45 @@
# Sparge Attention (Composable Kernel)

A Composable Kernel port of [SpargeAttn](https://github.com/thu-ml/SpargeAttn) for AMD GPUs. Both the block-map pipeline (mean-pool → cosine sim → pooled QK → top-k LUT) and the sparse FMHA stage run on-GPU. Two attention backends are exposed via `-pipeline=vsa` (default, faster) and `-pipeline=jenga` (async K/V load variant).
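The block-map stage above can be sketched in a few lines of numpy. This is a hedged illustration of the mean-pool → pooled QK → top-k selection steps only (the cosine-similarity gate and `cdfthreshd` path are omitted); `blockmap_lut`, its argument layout, and the pooling scheme are illustrative, not the kernel's actual tiling:

```python
import numpy as np

def blockmap_lut(q, k, block, topk_frac):
    """Toy per-query-block LUT via pooled QK + top-k.

    q, k: [seq, dim]; block: pooling block size; topk_frac: kept fraction.
    Returns a boolean block map [N_q, N_k] marking which K blocks each
    Q block attends to.
    """
    def pool(x):
        # mean-pool each `block` rows down to one representative vector
        n = x.shape[0] // block
        return x[: n * block].reshape(n, block, -1).mean(axis=1)

    qp, kp = pool(q), pool(k)            # [N_q, d], [N_k, d]
    scores = qp @ kp.T                   # pooled QK scores, [N_q, N_k]
    keep = max(1, int(round(topk_frac * kp.shape[0])))
    # per query block, keep the `keep` highest-scoring K blocks
    idx = np.argsort(-scores, axis=1)[:, :keep]
    block_map = np.zeros_like(scores, dtype=bool)
    np.put_along_axis(block_map, idx, True, axis=1)
    return block_map
```

The sparse FMHA stage then visits only the K blocks marked `True` in each row.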

## Status vs Upstream

Implemented:

- per-block mean-pool, cosine similarity, pooled QK
- top-k / `cdfthreshd` block selection, BlockMap LUT
- sparse FMHA (both `vsa` and `jenga` backends)
- per-head `topk` / `simthreshd1` / `cdfthreshd`

Not yet ported (upstream pinned to commit [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a)):

- **K smoothing** — pre-pool `k -= km`; required for diffusion / video checkpoints (CogVideoX, Mochi-1, Flux, OpenSora, SD 3.5) ([spas_sage_attn/core.py:L53](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L53))
- **is_causal mask in pooled score** — required for causal-LM prefill (Llama, Qwen) ([spas_sage_attn/utils.py:L338](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L338))
- **attention_sink** — column 0 forced ON; upstream is hard-wired to `True` at inference ([spas_sage_attn/autotune.py:L355](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/autotune.py#L355))
- **pv_threshold per-Q-tile skip in attn kernel** — pure perf, ~5–15% on the dominant attention slice ([spas_sage_attn/core.py:L265](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/core.py#L265))
- **Sort-based top-k selection** — replaces our O(N_k^2) iterative argmax; matters at long seqlen (s ≥ 16k) ([spas_sage_attn/utils.py:L345](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L345))
- **Q/K int8 quant fusion in pool kernel** — enables a downstream int8 GEMM0 in the attn kernel ([spas_sage_attn/utils.py:L371](https://github.com/thu-ml/SpargeAttn/blob/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a/spas_sage_attn/utils.py#L371))
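For reference, the unported K smoothing is a small pre-pool transform: subtract the mean key before pooling so block similarity reflects per-block variation rather than the shared component. A hedged numpy sketch — the `[heads, seq, dim]` layout and mean axis are assumptions for illustration, not the upstream tensor layout:

```python
import numpy as np

def smooth_k(k):
    """K smoothing (`k -= km`): remove the mean key vector per head
    before mean-pooling, as the upstream item above describes."""
    km = k.mean(axis=1, keepdims=True)  # [heads, 1, dim] mean key
    return k - km
```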

## Performance

At b=2 h=32 s=16384 d=128 fp16, sparge (vsa backend) reaches **1.78× FMHA throughput at topk=0.4** and **5.04× at topk=0.1**, and stays above 1.0× across the full topk range.

![speedup vs sparsity](docs/speedup_vs_sparsity.png)

*Speedup vs FMHA, b=2 h=32 s=16384 d=128 fp16. Shape chosen to match Fig. 10 of the SpargeAttn paper ([arXiv:2502.18137](https://arxiv.org/abs/2502.18137); Mochi-1, 22K context, head_dim=128); s=16384 is the closest grid point. Gray-outlined points have >30% inter-rep spread.*

![kernel breakdown](docs/kernel_breakdown.png)

*BlockMap (`_pre`) stacked on attention (`_attn`), b=2 h=32 d=128 fp16 topk=0.4. BlockMap is roughly 17% of total at s=16384.*

## Usage

```bash
ninja tile_example_sparge
./bin/tile_example_sparge -pipeline=vsa -b=2 -h=32 -s=16384 -d=128 -topk=0.4 -simthreshd1=0.001
```

Add `-v=1` for CPU validation; use a small shape (`-b=1 -h=2 -s=512`), since the full-shape CPU reference scales O(s²) and runs 30+ minutes at s=8k, hours at s=16k.
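The O(s²) cost of the host reference is easy to see in a naive transcription. This is a generic numpy sketch of a dense reference attention (not the example's actual validator): every query scores every key, so work and memory grow quadratically in s.

```python
import numpy as np

def ref_attention(q, k, v, scale=None):
    """Naive dense reference attention over [seq, dim] fp32 inputs.

    Builds the full [seq, seq] score matrix — the O(s^2) term that
    makes full-shape CPU validation impractical at s=16k."""
    scale = scale if scale is not None else 1.0 / np.sqrt(q.shape[-1])
    s = (q @ k.T) * scale                  # [seq, seq] scores
    s = s - s.max(axis=-1, keepdims=True)  # numerically stable softmax
    p = np.exp(s)
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v
```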

## References

- [SpargeAttn upstream](https://github.com/thu-ml/SpargeAttn) (pinned to [`ae5b629`](https://github.com/thu-ml/SpargeAttn/tree/ae5b629ebb41e41f86b3ea2ab5a3283f13ac151a))
- [Paper — Zhang et al., arXiv:2502.18137](https://arxiv.org/abs/2502.18137)
BIN
example/ck_tile/50_sparse_attn/docs/kernel_breakdown.png
Normal file (binary, 83 KiB, not shown)
258
example/ck_tile/50_sparse_attn/docs/plot_sparge_perf.py
Normal file
@@ -0,0 +1,258 @@
#!/usr/bin/env python3
"""Plot sparge perf charts from full_grid.csv.

Re-run with different fixed (b, h, s, dtype, topk) by editing the constants below.
No GPU / no srun / no rebuild — pure matplotlib from CSV.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# ----------------------------------------------------------------------
# Tunable constants — edit these to regenerate for a different point.
# ----------------------------------------------------------------------
CSV_PATH = "/home/AMD/ginolu12/gino_tmp/full_grid.csv"
OUT_DIR = os.path.dirname(os.path.abspath(__file__))

# Chart 1 — speedup vs topk for one fixed (b, h, s, dtype)
CHART1_B = 2
CHART1_H = 32
CHART1_S = 16384
CHART1_DTYPE = "fp16"
CHART1_HEAD_DIM = 128  # for title only

# Chart 2 — kernel breakdown across s for fixed (b, h, dtype, topk)
CHART2_B = 2
CHART2_H = 32
CHART2_DTYPE = "fp16"
CHART2_TOPK = 0.4
CHART2_S_LIST = [2048, 4096, 8192, 16384]
CHART2_HEAD_DIM = 128  # for title only

DPI = 140


# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def is_fail(note):
    if not isinstance(note, str):
        return False
    return "FAIL" in note


def is_high_spread(note):
    if not isinstance(note, str):
        return False
    return "HIGH_SPREAD" in note


def load_data():
    df = pd.read_csv(CSV_PATH)
    return df


# ----------------------------------------------------------------------
# Chart 1
# ----------------------------------------------------------------------
def plot_chart1(df, out_path):
    sel = df[
        (df["b"] == CHART1_B)
        & (df["h"] == CHART1_H)
        & (df["s"] == CHART1_S)
        & (df["dtype"] == CHART1_DTYPE)
    ].copy()
    sel = sel.sort_values("topk").reset_index(drop=True)

    if sel.empty:
        print(f"[chart1] WARNING: no rows for b={CHART1_B} h={CHART1_H} s={CHART1_S} dtype={CHART1_DTYPE}")
        return [], 0

    # Drop fully failed rows but keep partial-fail rows; we'll mask per-series.
    # Convert numeric columns
    for col in ["sparge_jenga", "sparge_vsa", "sparse_jenga", "sparse_vsa", "fmha_us"]:
        sel[col] = pd.to_numeric(sel[col], errors="coerce")

    fmha = sel["fmha_us"]

    # Compute speedups; rows with FAIL on a given column will have NaN already.
    series = {
        "sparge_vsa": fmha / sel["sparge_vsa"],
        "sparge_jenga": fmha / sel["sparge_jenga"],
        "sparse_vsa": fmha / sel["sparse_vsa"],
        "sparse_jenga": fmha / sel["sparse_jenga"],
    }

    style = {
        "sparge_vsa": {"color": "#1f77b4", "marker": "o", "lw": 2.0},
        "sparge_jenga": {"color": "#ff7f0e", "marker": "s", "lw": 2.0},
        "sparse_vsa": {"color": "#2ca02c", "marker": "^", "lw": 1.5, "ls": "--"},
        "sparse_jenga": {"color": "#d62728", "marker": "v", "lw": 1.5, "ls": "--"},
    }

    fig, ax = plt.subplots(figsize=(8.5, 5.5), dpi=DPI)

    x = sel["topk"].to_numpy()

    # HIGH_SPREAD overlay first (under main markers)
    hs_mask = sel["note"].apply(is_high_spread)
    high_spread_cells = []
    if hs_mask.any():
        for _, row in sel[hs_mask].iterrows():
            high_spread_cells.append((row["topk"], row["max_spread_pct"]))
        # gray ring underneath every series's data point at that x
        for label, sp in series.items():
            xs_hs = x[hs_mask.to_numpy()]
            ys_hs = sp[hs_mask.to_numpy()].to_numpy()
            ax.scatter(xs_hs, ys_hs, s=180, facecolors="none",
                       edgecolors="gray", linewidths=1.5, zorder=2)

    for label, sp in series.items():
        st = style[label]
        ax.plot(x, sp.to_numpy(), label=label,
                color=st["color"], marker=st["marker"],
                linewidth=st["lw"], linestyle=st.get("ls", "-"),
                markersize=7, zorder=3)

    ax.axhline(1.0, color="black", linestyle=":", linewidth=1.2, label="fmha (baseline)", zorder=1)

    ax.set_xlabel("topk (kept fraction)")
    ax.set_ylabel("speedup vs FMHA dense (×)")
    ax.set_title(
        f"Speedup vs FMHA "
        f"(b={CHART1_B} h={CHART1_H} s={CHART1_S} d={CHART1_HEAD_DIM} {CHART1_DTYPE})"
    )
    ax.grid(True, which="both", linestyle=":", alpha=0.6)
    ax.set_xticks(np.arange(0.1, 0.71, 0.1))
    ax.legend(loc="best", framealpha=0.9)

    # Footnote about HIGH_SPREAD overlay
    if high_spread_cells:
        ax.text(0.01, -0.16,
                "Gray rings: HIGH_SPREAD cells (high run-to-run variance)",
                transform=ax.transAxes, fontsize=8, color="gray")

    fig.tight_layout()
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    return high_spread_cells, os.path.getsize(out_path)


# ----------------------------------------------------------------------
# Chart 2
# ----------------------------------------------------------------------
def plot_chart2(df, out_path):
    sel = df[
        (df["b"] == CHART2_B)
        & (df["h"] == CHART2_H)
        & (df["dtype"] == CHART2_DTYPE)
        & (np.isclose(df["topk"], CHART2_TOPK))
        & (df["s"].isin(CHART2_S_LIST))
    ].copy()
    sel = sel.sort_values("s").reset_index(drop=True)

    if sel.empty:
        print(f"[chart2] WARNING: no rows for b={CHART2_B} h={CHART2_H} dtype={CHART2_DTYPE} topk={CHART2_TOPK}")
        return 0

    for col in ["sparge_jenga_pre", "sparge_jenga_attn",
                "sparge_vsa_pre", "sparge_vsa_attn", "fmha_us"]:
        sel[col] = pd.to_numeric(sel[col], errors="coerce")

    s_vals = sel["s"].to_numpy()
    n = len(s_vals)
    idx = np.arange(n, dtype=float)

    width = 0.35
    offset = width / 2 + 0.02

    fig, ax = plt.subplots(figsize=(9.0, 5.8), dpi=DPI)

    # Jenga bars (left of group)
    jenga_pre = sel["sparge_jenga_pre"].to_numpy()
    jenga_attn = sel["sparge_jenga_attn"].to_numpy()
    vsa_pre = sel["sparge_vsa_pre"].to_numpy()
    vsa_attn = sel["sparge_vsa_attn"].to_numpy()
    fmha_vals = sel["fmha_us"].to_numpy()

    color_jenga_pre = "#fdbf6f"   # light orange
    color_jenga_attn = "#ff7f0e"  # orange
    color_vsa_pre = "#a6cee3"     # light blue
    color_vsa_attn = "#1f77b4"    # blue

    ax.bar(idx - offset, jenga_pre, width,
           color=color_jenga_pre, edgecolor="black", linewidth=0.6,
           label="sparge_jenga _pre (BlockMap)")
    ax.bar(idx - offset, jenga_attn, width, bottom=jenga_pre,
           color=color_jenga_attn, edgecolor="black", linewidth=0.6,
           label="sparge_jenga _attn")
    ax.bar(idx + offset, vsa_pre, width,
           color=color_vsa_pre, edgecolor="black", linewidth=0.6,
           label="sparge_vsa _pre (BlockMap)")
    ax.bar(idx + offset, vsa_attn, width, bottom=vsa_pre,
           color=color_vsa_attn, edgecolor="black", linewidth=0.6,
           label="sparge_vsa _attn")

    # Add total labels on top of each stack
    totals_jenga = jenga_pre + jenga_attn
    totals_vsa = vsa_pre + vsa_attn
    for i in range(n):
        ax.text(idx[i] - offset, totals_jenga[i], f"{totals_jenga[i]:.0f}",
                ha="center", va="bottom", fontsize=8)
        ax.text(idx[i] + offset, totals_vsa[i], f"{totals_vsa[i]:.0f}",
                ha="center", va="bottom", fontsize=8)

    # FMHA reference: short horizontal dashed segment per group
    seg_half = 0.40
    fmha_label_done = False
    for i in range(n):
        ax.hlines(fmha_vals[i], idx[i] - seg_half, idx[i] + seg_half,
                  colors="black", linestyles="dashed", linewidth=1.2,
                  label="fmha dense (reference)" if not fmha_label_done else None,
                  zorder=5)
        ax.text(idx[i] + seg_half + 0.02, fmha_vals[i],
                f"fmha {fmha_vals[i]:.0f}", fontsize=7, va="center", color="black")
        fmha_label_done = True

    ax.set_xticks(idx)
    ax.set_xticklabels([f"s={s}" for s in s_vals.astype(int)])
    ax.set_xlabel("sequence length (s)")
    ax.set_ylabel("kernel time (µs)")
    ax.set_title(
        f"Sparge kernel time breakdown "
        f"(b={CHART2_B} h={CHART2_H} d={CHART2_HEAD_DIM} {CHART2_DTYPE}, topk={CHART2_TOPK})"
    )
    ax.grid(True, axis="y", linestyle=":", alpha=0.6)
    ax.legend(loc="upper left", framealpha=0.9, fontsize=9)

    # log-y is too aggressive — leave linear; bars will just be tall.
    fig.tight_layout()
    fig.savefig(out_path, dpi=DPI, bbox_inches="tight")
    plt.close(fig)
    return os.path.getsize(out_path)


# ----------------------------------------------------------------------
# Main
# ----------------------------------------------------------------------
def main():
    os.makedirs(OUT_DIR, exist_ok=True)
    df = load_data()

    chart1_path = os.path.join(OUT_DIR, "speedup_vs_sparsity.png")
    chart2_path = os.path.join(OUT_DIR, "kernel_breakdown.png")

    hs_cells, size1 = plot_chart1(df, chart1_path)
    size2 = plot_chart2(df, chart2_path)

    print(f"Wrote {chart1_path} ({size1} bytes)")
    print(f"Wrote {chart2_path} ({size2} bytes)")

    if hs_cells:
        print("HIGH_SPREAD cells in chart-1 selection:")
        for topk, pct in hs_cells:
            print(f"  topk={topk} max_spread_pct={pct}")
    else:
        print("No HIGH_SPREAD cells in chart-1 selection.")


if __name__ == "__main__":
    main()
BIN
example/ck_tile/50_sparse_attn/docs/speedup_vs_sparsity.png
Normal file (binary, 124 KiB, not shown)
@@ -5,6 +5,9 @@
#include "sparge_blockmap_trek.hpp"
#include "ck_tile/ops/fmha/block/variants.hpp"

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdint>
#include <iostream>

// ============================================================================
@@ -61,6 +64,9 @@ using bmap_fp16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::half_t, //
using bmap_fp16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_fp16_problem>;
using bmap_fp16_kernel = ck_tile::SpargeBlockMapKernel<bmap_fp16_pipeline>;

using kstats_fp16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_fp16_problem>;
using kstats_fp16_kernel = ck_tile::SpargeKStatsKernel<kstats_fp16_pipeline>;

// ============================================================================
// bf16: D=128, kM0=64, kN0=128
// ============================================================================
@@ -112,6 +118,78 @@ using bmap_bf16_problem = ck_tile::BlockFmhaPipelineProblem<ck_tile::bf16_t, //
using bmap_bf16_pipeline = ck_tile::SpargeBlockMapPipeline<bmap_bf16_problem>;
using bmap_bf16_kernel = ck_tile::SpargeBlockMapKernel<bmap_bf16_pipeline>;

using kstats_bf16_pipeline = ck_tile::SpargeKStatsPipeline<bmap_bf16_problem>;
using kstats_bf16_kernel = ck_tile::SpargeKStatsKernel<kstats_bf16_pipeline>;

// ============================================================================
// Internal K-stat workspace (R20): process-lifetime lazy hipMalloc, sized
// to the largest (batch, nhead_k, N_k, D) seen so far. Caller API unchanged.
// ============================================================================

namespace {
struct KStatsWorkspace
{
    void* pooled_k_dev = nullptr; // [batch, nhead_k, N_k, D] fp32
    void* sim_k_dev    = nullptr; // [batch, nhead_k, N_k] uint8
    size_t pooled_k_bytes = 0;
    size_t sim_k_bytes    = 0;

    void ensure(int batch, int nhead_k, int N_k, int D)
    {
        const size_t need_p = static_cast<size_t>(batch) * nhead_k * N_k * D * sizeof(float);
        const size_t need_s = static_cast<size_t>(batch) * nhead_k * N_k * sizeof(uint8_t);
        if(need_p > pooled_k_bytes)
        {
            if(pooled_k_dev != nullptr) (void)hipFree(pooled_k_dev);
            (void)hipMalloc(&pooled_k_dev, need_p);
            pooled_k_bytes = need_p;
        }
        if(need_s > sim_k_bytes)
        {
            if(sim_k_dev != nullptr) (void)hipFree(sim_k_dev);
            (void)hipMalloc(&sim_k_dev, need_s);
            sim_k_bytes = need_s;
        }
    }
};

KStatsWorkspace& g_kstats_ws()
{
    static KStatsWorkspace ws;
    return ws;
}

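The workspace above follows a grow-only pattern: reallocate only when a request exceeds current capacity, so steady-state launches pay no allocation cost. A minimal Python sketch of the same idea (a `bytearray` stands in for the hipFree/hipMalloc pair; instrumentation names are illustrative):

```python
class GrowOnlyWorkspace:
    """Process-lifetime, grow-only buffer: reallocates only when the
    requested size exceeds capacity, mirroring KStatsWorkspace::ensure()."""

    def __init__(self):
        self.buf = None
        self.capacity = 0
        self.allocs = 0  # count of real (re)allocations

    def ensure(self, nbytes):
        if nbytes > self.capacity:
            self.buf = bytearray(nbytes)  # stand-in for hipFree + hipMalloc
            self.capacity = nbytes
            self.allocs += 1
        return self.buf

ws = GrowOnlyWorkspace()
ws.ensure(1024)
ws.ensure(512)   # smaller request: reuses the existing buffer
ws.ensure(4096)  # larger request: reallocates once
```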
template <typename KStatsKernel, typename BlockMapKernel>
void launch_kstats_then_blockmap(sparge_blockmap_args args, const ck_tile::stream_config& s)
{
    const int N_k = ck_tile::integer_divide_ceil(args.seqlen_k, BlockMapKernel::kN0);
    const int D   = BlockMapKernel::D;
    auto& ws      = g_kstats_ws();
    ws.ensure(args.batch, args.nhead_k, N_k, D);

    // Stage 1: K stats
    {
        auto [kargs, grids] =
            sparge_kstats_create_kargs_and_grids<KStatsKernel>(args, ws.pooled_k_dev, ws.sim_k_dev);
        const dim3 blocks = KStatsKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = KStatsKernel::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(KStatsKernel{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
    }
    // Stage 2: block_map (reads ws)
    {
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<BlockMapKernel>(
            args, ws.pooled_k_dev, ws.sim_k_dev);
        const dim3 blocks = BlockMapKernel::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = BlockMapKernel::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(BlockMapKernel{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
    }
}

} // namespace

// ============================================================================
// Dispatch
// ============================================================================
@@ -122,26 +200,20 @@ float sparge_blockmap_fwd(sparge_blockmap_traits traits,
{
    if(traits.data_type == "fp16" && traits.hdim_q == 128)
    {
        using k_ = bmap_fp16_kernel;
        if(s.log_level_ > 0)
            std::cout << ", sparge_blockmap_fp16_d128" << std::flush;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        return ck_tile::launch_kernel(
            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
            launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s_);
        });
    }

    if(traits.data_type == "bf16" && traits.hdim_q == 128)
    {
        using k_ = bmap_bf16_kernel;
        if(s.log_level_ > 0)
            std::cout << ", sparge_blockmap_bf16_d128" << std::flush;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        return ck_tile::launch_kernel(
            s, ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
        return ck_tile::launch_kernel(s, [=](const ck_tile::stream_config& s_) {
            launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s_);
        });
    }

    if(s.log_level_ > 0)
@@ -160,23 +232,13 @@ void sparge_blockmap_fwd_oneshot(sparge_blockmap_traits traits,
{
    if(traits.data_type == "fp16" && traits.hdim_q == 128)
    {
        using k_ = bmap_fp16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        launch_kstats_then_blockmap<kstats_fp16_kernel, bmap_fp16_kernel>(args, s);
        return;
    }

    if(traits.data_type == "bf16" && traits.hdim_q == 128)
    {
        using k_ = bmap_bf16_kernel;
        auto [kargs, grids] = sparge_blockmap_create_kargs_and_grids<k_>(args);
        const dim3 blocks = k_::BlockSize();
        constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
        ck_tile::make_kernel<kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
            ck_tile::stream_config{s.stream_id_});
        launch_kstats_then_blockmap<kstats_bf16_kernel, bmap_bf16_kernel>(args, s);
        return;
    }

@@ -8,7 +8,9 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_blockmap_pipeline.hpp"
#include "ck_tile/ops/sparse_attn/pipeline/sparge_kstats_pipeline.hpp"
#include "ck_tile/ops/sparse_attn/kernel/sparge_blockmap_kernel.hpp"
#include "ck_tile/ops/sparse_attn/kernel/sparge_kstats_kernel.hpp"

#include "fmha_fwd_trek.hpp"

@@ -45,6 +47,15 @@ struct sparge_blockmap_args
    void* block_map_ptr;
    void* lut_ptr;
    void* valid_block_num_ptr;

    // R21A Phase 4 + R21B fix: optional per-head superparams. nullptr => use scalar.
    // Buffer sizes match SpargeAttn upstream contract (utils.py:324-328: all sized
    // by Headnum=q.size(1)=nhead_q). K-side kernel still indexes [hk] into the
    // first nhead_k entries — for MHA equivalent to old [nhead_k] sizing, for
    // MQA/GQA aligns to upstream tuned ckpt layout.
    const float* simthreshd1_per_head_ptr = nullptr; // size = nhead_q floats (kernel reads [0..nhead_k-1])
    const float* cdfthreshd_per_head_ptr  = nullptr; // size = nhead_q floats
    const float* topk_per_head_ptr        = nullptr; // size = nhead_q floats
};

struct sparge_blockmap_traits
@@ -57,7 +68,9 @@ struct sparge_blockmap_traits
// Create kernel args and grid dimensions
// ============================================================================
template <typename BlockMapKernel>
auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args,
                                            const void* pooled_k_ws_ptr,
                                            const void* sim_k_ws_ptr)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = BlockMapKernel::MakeKargs(args.q_ptr,
@@ -79,12 +92,38 @@ auto sparge_blockmap_create_kargs_and_grids(sparge_blockmap_args args)
                                           args.scale,
                                           args.block_map_ptr,
                                           args.lut_ptr,
                                           args.valid_block_num_ptr);
                                           args.valid_block_num_ptr,
                                           pooled_k_ws_ptr,
                                           sim_k_ws_ptr,
                                           args.topk_per_head_ptr,
                                           args.cdfthreshd_per_head_ptr);

    dim3 grids = BlockMapKernel::GridSize(args.batch, args.nhead_q, args.seqlen_q);
    return ck_tile::make_tuple(kargs, grids);
}

template <typename KStatsKernel>
auto sparge_kstats_create_kargs_and_grids(sparge_blockmap_args args,
                                          void* pooled_k_ws_ptr,
                                          void* sim_k_ws_ptr)
{
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = KStatsKernel::MakeKargs(args.k_ptr,
                                         args.seqlen_k,
                                         args.hdim_q,
                                         args.nhead_k,
                                         args.stride_k,
                                         args.nhead_stride_k,
                                         args.batch_stride_k,
                                         args.simthreshd1,
                                         pooled_k_ws_ptr,
                                         sim_k_ws_ptr,
                                         args.simthreshd1_per_head_ptr);

    dim3 grids = KStatsKernel::GridSize(args.batch, args.nhead_k, args.seqlen_k);
    return ck_tile::make_tuple(kargs, grids);
}

// ============================================================================
// Hand-written template instantiation dispatch
// ============================================================================
@@ -105,7 +105,10 @@ auto create_args(int argc, char* argv[])
        .insert("seed", "42", "random seed")
        .insert("warmup", "5", "warmup iterations")
        .insert("repeat", "20", "benchmark iterations")
        .insert("kname", "0", "print kernel name");
        .insert("kname", "0", "print kernel name")
        .insert("perhead", "0",
                "R21A Phase 4: 0=scalar (default), 1=per-head [H] superparam test "
                "(varies topk[h] = topk * (1 + 0.5*(h - H/2)/H), simthreshd1 unchanged)");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
@@ -135,6 +138,7 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
    int warmup  = arg_parser.get_int("warmup");
    int repeat  = arg_parser.get_int("repeat");
    int kname   = arg_parser.get_int("kname");
    int perhead = arg_parser.get_int("perhead");

    if(nhead_k < 0) nhead_k = nhead;
    if(seqlen_k < 0) seqlen_k = seqlen_q;
@@ -231,6 +235,33 @@ bool run_test(const ck_tile::ArgParser& arg_parser)
    bmap_args.lut_ptr             = (pipeline == "vsa") ? lut_dev.GetDeviceBuffer() : nullptr;
    bmap_args.valid_block_num_ptr = (pipeline == "vsa") ? valid_bn_dev.GetDeviceBuffer() : nullptr;

    // R21A Phase 4 + R21B fix: per-head superparam buffers, all sized [nhead_q]
    // to match SpargeAttn upstream contract (utils.py:324-328, Headnum=q.size(1)).
    // K-side kernel reads only the first nhead_k entries via [hk].
    ck_tile::DeviceMem topk_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem sim1_per_head_dev(static_cast<size_t>(nhead) * sizeof(float));
    ck_tile::DeviceMem cdf_per_head_dev (static_cast<size_t>(nhead) * sizeof(float));
    if(perhead != 0)
    {
        std::vector<float> topk_h(nhead);
        std::vector<float> sim1_h(nhead);
        std::vector<float> cdf_h (nhead);
        for(int h = 0; h < nhead; ++h)
        {
            // small per-head jitter around scalar topk so sparsity differs by head
            const float jitter = 0.5f * (static_cast<float>(h - nhead / 2) / nhead);
            topk_h[h] = topk * (1.0f + jitter);
            sim1_h[h] = simthreshd1; // bit-identical to scalar (kernel reads [0..nhead_k-1])
            cdf_h[h]  = cdfthreshd;
        }
        topk_per_head_dev.ToDevice(topk_h.data());
        sim1_per_head_dev.ToDevice(sim1_h.data());
        cdf_per_head_dev .ToDevice(cdf_h.data());
        bmap_args.topk_per_head_ptr        = static_cast<const float*>(topk_per_head_dev.GetDeviceBuffer());
        bmap_args.simthreshd1_per_head_ptr = static_cast<const float*>(sim1_per_head_dev.GetDeviceBuffer());
        bmap_args.cdfthreshd_per_head_ptr  = static_cast<const float*>(cdf_per_head_dev.GetDeviceBuffer());
    }
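The jitter formula above spreads `topk[h]` symmetrically around the scalar value. A quick Python transcription of the same arithmetic (an illustrative check, not part of the harness; note the integer division matches the C++ `nhead / 2`):

```python
def per_head_topk(topk, nhead):
    # topk[h] = topk * (1 + 0.5*(h - H/2)/H), as in the harness loop above
    return [topk * (1.0 + 0.5 * (h - nhead // 2) / nhead) for h in range(nhead)]

# For topk=0.4, H=32: head 0 gets 0.4*0.75 = 0.3 (-25%), the middle head
# gets exactly 0.4, and the top head gets 0.4*(1 + 0.5*15/32) ≈ 0.494.
vals = per_head_topk(0.4, 32)
```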

    // ---- build attention args ----
    ck_tile::stream_config stream_cfg;
    stream_cfg.stream_id_ = nullptr;