Commit Graph

1011 Commits

Author SHA1 Message Date
juuso-oskari
9d7cc3ee9e CK-UA: extend FP8 to the 16x16x32 _m16 decode tier via LDS roundtrip
The 32x32x16 tiers (prefill_d{64,128}, decode_d{64,128}_m{32,64,128}) keep
the cheap in-register `ds_bpermute_b32` cross-lane swap that fixes the
QK-C / PV-A per-thread alias for the union'd `sp_compute` / `p`.

The 16x16x32 m16 tiers (decode_d{64,128}_m16) cannot use the swap -- the
MFMA puts the paired-lane bit at a different position and the
sub=0/sub=1 4-fp8 chunks no longer map onto each other. We add a
layout-agnostic LDS roundtrip as the `else` branch, gated by the same
`PVWarpTile` constexpr:

  - Hoist two distribution-bound windows over the existing `p_lds`
    region (one bound to the QK-C output distribution, one to the PV-A
    input distribution). Done once per kernel invocation.
  - In `fmha_alu1`, after the cvt_pk_fp8_f32 packing chain, view the
    union's bytes as a `static_distributed_tensor<fp8>` in the QK-C
    distribution, `store_tile` it through `p_lds` in canonical (M, N)
    order, `s_barrier`, then `load_tile` back with the PV-A
    distribution and copy into `sp(idx).p`.

A/B'd a uniform LDS-roundtrip (no fast-path) vs the split: pure LDS
regressed decode_m128 by ~1.5x end-to-end (CK FP8 dropped from
~0.39x of Triton FP8 to ~0.16x), driven by the extra block-wide
barrier on the 4-warp decode path. Keeping the swap for 32x32x16
preserves the previously-tuned perf.

Dispatcher (`unified_attention.cpp`) now FP8-enables every UA variant
including decode_d{64,128}_m16. Four new instance .cpp files
(`unified_attention_d{64,128}_fp8_{mask,nmask}_decode_t.cpp`)
instantiate the m16 FP8 kernels.

Pytest (`test_unified_attention_ck_correctness.py`):
  - 245 BF16/FP16: pass (no regression from the pipeline edit).
  - 160 FP8: pass (was 112 before m16 enablement).
  - 80 skipped: block_size<32 or query_len>kv_len -- pre-existing.

Single-shape m16 dispatches verified on gfx950:
  b=128 sq=1 hq=hk=8 d=128 fp8 PASS  (CK 0.109 ms / Triton 0.043 ms)
  b=128 sq=1 hq=hk=8 d=64  fp8 PASS  (CK 0.077 ms / Triton 0.039 ms)

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 20:00:35 +00:00
juuso-oskari
63c75277a0 CK-UA: enable FP8 (e4m3) for prefill/m128 and the 32x32x16 small-tile decode variants
Full pipeline support for FP8 (e4m3fn on gfx950 / e4m3fnuz on gfx942)
in the unified-attention kernel, gated to the 32x32x16 MFMA tiers in
both d=64 and d=128 ladders: prefill_d{64,128}, decode_d{64,128}_m128,
decode_d128_m32, and decode_d64_m64. The 16x16x32 _m16 tiers stay
BF16/FP16-only -- the QK-C and PV-A per-thread layouts there differ
by an M<->N swap that the current slot-swap fixup cannot express; a
full per-thread transpose (most likely via LDS) is needed.

Pipeline (unified_attention_pipeline.hpp):
* `fmha_alu1` now performs a cross-lane P-tile fixup right after the
  FP8 packing of softmax(P). It's a `ds_bpermute_b32` between paired
  lanes `lane ^ 32`, swapping sub=0 slot[k_base+4..k_base+7] with
  sub=1 slot[k_base..k_base+3] for every 8-fp8 chunk. This realigns
  the FP8 packed P operand with PV-A's `Single` AttrNumAccess
  per-thread layout, which is necessary because the QK-C output and
  PV-A input alias byte-for-byte via the sp_compute/p union -- and
  for FP8 the two warp-gemm layouts no longer agree (BF16/FP16 keep
  Double AttrNumAccess in the PV gemm, which matches QK-C natively).
  Gated on `Gemm1WarpTile == 32x32x16`; FP8-only (BF16/FP16 paths take
  the existing cvt_pk path unchanged).

Default policy (unified_attention_pipeline_default_policy.hpp):
* PV warp gemm now selects `WGAttrNumAccessEnum::Single` when V is
  fp8/bf8 and `Double` otherwise. Forced by load_tile_transpose's
  SubMinDim = 64-bit / sizeof(V) constraint: for FP8 SubMinDim=8 and
  kABKPerLane=8 only Single satisfies the validation static_asserts.
* GetAlignmentK / GetAlignmentV on gfx950 drop to 4 B/lane for fp8/
  bf8. The natural 16 B/lane async-load that BF16/FP16 use leaves
  NumIssues = 0 for the FP8 tile shapes we compile, and 8 B/lane
  fails the dword / dwordx3 / dwordx4 constraint in
  amd_buffer_addressing_builtins. 4 B/lane gives NumIssues >= 1 on
  every targeted variant and is the same alignment the gfx942
  fallback already used. BF16/FP16 keep the full 16 B/lane path so
  existing perf is unchanged.
* GetSmemSizeKV adds a `VLoadDescSize` lower bound. The
  MakeVLdsLoadBlockDescriptor's element span dominates the banked
  SingleVSize only for FP8 (small per-lane KVector + fixed
  kVLdsPadInBytes = 64), so without it FP8 hits the GetSmemSizeKV
  static_asserts. BF16/FP16 are unaffected.

Warp-gemm headers + dispatcher:
* New `WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed_T<AttrNumAccess>`
  template alias in warp_gemm.hpp (mirrors the existing BF16 32x32x16
  CTransposed template), used by the PV gemm to thread the FP8
  Single AttrNumAccess through.
* New Dispatcher specialization for
  <fp8_t, fp8_t, float, 32, 32, 16, true, false, false, EDouble>
  in warp_gemm_dispatcher.hpp routing to the new template.

ABI / dispatcher (unified_attention.{cpp,hpp}, unified_attention_impl.hpp):
* New `fp8` value in `unified_attention_args::data_type_enum` (selects
  e4m3fn on gfx950 via CK_TILE_USE_OCP_FP8, e4m3fnuz elsewhere).
* New `unified_attention_problem_traits<...::fp8>` alias:
  qkvp_dtype = ck_tile::fp8_t, acc_dtype = float, o_dtype = bf16_t
  (matches the Triton reference), lse_dtype = float.
* Per-tensor `q_descale` / `k_descale` / `v_descale` floats on
  `unified_attention_args` (default 1.0f so non-FP8 round-trips
  cleanly). The pipeline folds q_descale*k_descale into the softmax
  scale and applies v_descale once to o_acc after the 1/l norm --
  same semantics as Triton's q_scale/k_scale/v_scale.
* `dispatch_variant<>` enables FP8 on prefill_d{64,128},
  decode_d{64,128}_m128, decode_d128_m32, decode_d64_m64. The
  16x16x32 _m16 tiers return (false, -1.f) for now (see top comment).

Instances:
* 12 new FP8 .cpp files under example/.../42_unified_attention/
  instances/ covering the 6 enabled variants x {mask, nmask}.

Validation: 112 / 0 / 128 in the FP8 pytest sweep (passed / failed /
m16-skipped); 245 / 245 in the BF16/FP16 sweep (no regression).
Functional correctness is within the FP8 quant-noise tolerance the
Triton FP8 suite uses (atol/rtol = 1.5e-1). Perf still trails Triton
across the enabled tiers (CK FP8 / Triton FP8 = 0.39-0.69x on the
shapes we benchmarked); that's a separate workstream.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 17:34:50 +00:00
juuso-oskari
1f69421434 CK-UA: dispatch K/V async load on cache_ptr_int32_overflow_possible
The shared-SRD buffer_load_dword_lds path that K_mem_load / V_mem_load use
wraps the per-lane voffset (int32 bytes) once
  num_blocks * page_size * row_stride * sizeof(T) > INT32_MAX,
silently returning wrong data on large paged-KV pools (e.g. >4 GB caches).

Add a second path, async_load_tile_raw_long, that issues the same load via
__builtin_amdgcn_global_load_lds with per-lane 64-bit base pointers, lifting
both 4 GB limits (SRD size + voffset). Per-issue LDS pointers are computed
explicitly because the intrinsic sets m0 itself, so the old m0_set / m0_inc
bookkeeping doesn't apply. The path also clamps lane_elem_off to the live
buffer range to mimic the original SRD's hardware OOB behaviour.

Dispatch is a wave-uniform runtime branch on a new
cache_ptr_int32_overflow_possible flag plumbed from
unified_attention_args through MakeKargs into the pipeline operator().
Small caches keep the original buffer_load throughput; only the (rare)
>4 GB cache pays the global_load_lds cost.

k_page_offsets / v_page_offsets are widened to long_index_t. The original
buffer_load path implicitly narrows back to int32 when forwarding through
async_get_vectorized_elements_raw, which is intentional and safe whenever
the overflow flag is false.

For diagnostics, also derive a constexpr KWaveSpanInN =
(LaneGroups - 1) * NumWarps + 1 inside the pipeline; when this exceeds
page_size a single buffer_load spans multiple random pages, so the
per-issue SRD-rebase optimisation (not implemented yet) would not apply
even on a sub-4 GB cache. Informational only today.

Test: ua-test-scripts correctness sweep (245/245 pass), plus
  test_single_shape.py -b 32 -sq 8192 -sk 120000 -hq 64 -hk 8 -d 64 \
      --num-blocks 1200000 --block-size 16 --test
which previously returned wrong data due to the int32 wrap and now passes
with max abs diff 1.22e-04 vs Triton.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-15 09:00:43 +00:00
juuso-oskari
d77f0bea63 CK-UA: collapse MHA/GQA variants -- one binary per (head_dim, kBlockM)
After moving kBlockQ to runtime in the previous commit, the static
NumQPerKV in `variant_config<V>` and the runtime-vs-static assert in
the kernel became the only things still tying a compiled binary to a
specific num_queries_per_kv. Drop both and the existing instances now
serve every num_qpkv that divides kBlockM evenly.

Concretely:
  * `variant_config<V>` -- remove the NumQPerKV field from every
    specialization.
  * `unified_attention_kernel_traits` -- remove the `num_queries_per_kv`
    / `kBlockQ = kBlockM / num_qpkv` derivation. The BlockTile's 2nd
    entry (the static `kBlockQ` exposed via UnifiedAttentionShape) is
    anchored at kBlockM so it describes the "num_qpkv == 1" fallback;
    the actual kBlockQ is always the runtime value.
  * `unified_attention_kernel_launch` -- recompute kBlockQ at host time
    from `args.num_queries_per_kv` for the total_num_q_blocks math.
  * `unified_attention_kernel.hpp` -- drop the
    `assert(kBlockQ_dyn == kBlockQ)` (it enforced the very coupling we
    just removed).
  * `unified_attention.cpp::select_config` -- collapse the two
    per-num_qpkv code paths into a single (head_dim, avg_rows,
    max_rows) ladder, where avg_rows = avg_q * num_qpkv.

Variant renames (8 variants):
  prefill_d128_mha       -> prefill_d128
  decode_d128_mha_m128   -> decode_d128_m128
  decode_d128_mha_m32    -> decode_d128_m32
  decode_d128_mha_m16    -> decode_d128_m16
  prefill_d64_gqa8       -> prefill_d64
  decode_d64_gqa8_m128   -> decode_d64_m128
  decode_d64_gqa8_m64    -> decode_d64_m64
  decode_d64_gqa8_m16    -> decode_d64_m16

The 16 d=64 instance files lose their `_gqa8` infix to match the
d=128 naming (file count unchanged: 16 dtypes x mask combos per
head_dim).

Validation:
  * Correctness suite: 241/245 (same 4 pre-existing int32-overflow
    failures in the prefill rebased-pointer path).
  * d=128 GQA-8 (a NEW combo we never had a binary for) -- runs
    correctly on the existing decode_d128_m* binaries with num_qpkv=8
    at runtime. max abs diff <= 1e-2 vs the torch reference at ql in
    {1, 4, 16}.
  * d=64 MHA (also a new combo) -- runs correctly on the existing
    decode_d64_m* binaries with num_qpkv=1. Same tolerance.
  * Perf sweep (b=4..256, sk=120000, MI300):
      d=64  GQA-8: speedups 1.28x..1.84x vs Triton (within 0.6%
                   of baseline).
      d=128 MHA:   speedups 0.98x..1.14x vs Triton (within 0.3%
                   of baseline).

Unlocked: adding new (head_dim, num_qpkv) combos no longer requires
new kernel binaries -- just a host-side heuristic update mapping the
combo to the appropriate (kBlockM, BlockWarps) ladder.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 12:15:55 +00:00
juuso-oskari
f5beedb2e9 Add CK-UA decode_d128_mha_m32 / _m16 small-Q tiers
For pure-decode workloads (sq=1) at d=128 the m128 tile wastes most of
its 128 query rows, capping CK below Triton on every batch size in our
sweep (4..256). Add two small-Q tiers that mirror the d=64 GQA-8 ladder:

  * decode_d128_mha_m16 : kBlockM=16, 1 warp, 16x16 MFMA  (tiny-decode)
  * decode_d128_mha_m32 : kBlockM=32, 1 warp, 32x32 MFMA  (tiny-decode)

select_config now ladders by (avg_q, max_q): m16 -> m32 -> m128 -> prefill.

d=128 MHA, hq=16/hk=16, sq=1, sk=120k, num_blocks=60k:
  batch  before    after    CK BW
      4  ~0.95x   0.98x   4.76 TB/s
      8  ~0.85x   1.29x   5.00 TB/s
     32  ~0.85x   1.14x   5.29 TB/s
     64  ~0.75x   0.93x   5.35 TB/s
    128  ~1.00x   1.09x   5.39 TB/s
    256  ~1.03x   1.02x   5.41 TB/s

Correctness suite stays at 241/245 (same 4 known int32-overflow
failures in the prefill path).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 11:48:19 +00:00
juuso-oskari
fb0d729fbb Collapse CK-UA traits into single kernel_traits<V, DType, IsMask> template
Replace 4 near-identical *_kernel_traits classes (~400 lines of repeated
shape/policy plumbing) with one templated `unified_attention_kernel_traits`
parameterized by `KernelVariant V`. The 6 dispatch_<variant> helpers in
unified_attention.cpp collapse into a single `dispatch_variant<V>` function
template that fans out over (dtype, mask).

Per-variant compile-time knobs (BlockM, BlockSize, warp count, MFMA shape,
pipeline policy, decode-grid flag) now live in one variant_config<V>
specialization each. "What's different between variants" is readable
top-to-bottom in a single block of code, and each instance .cpp shrinks to
a one-line `INST_UNIFIED_ATTENTION_DISPATCH(V, DTYPE, IS_MASK)` macro.

No behavior change. Correctness suite: 236/240 (same 4 known
num_blocks=32768 + d=128 MHA int32-overflow failures as baseline).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 10:35:15 +00:00
juuso-oskari
5bd8f73a28 Delete CK-UA bs32 variant family
The bs32 variants existed because pre-fix the pipeline required
kBlockN <= page_size, so page_size=32 forced a kBlockN=32 kernel
family. The multi-page-tile fix (commit 473869aba) lifted that
constraint and made kBlockN compile-time-independent of the runtime
page size, so the bs32 family is now redundant: every non-bs32 variant
is correct for any page_size.

This was validated in advance by forcing use_bs32=false in the
dispatcher and running the full correctness suite -- 236/240, identical
to baseline (the 4 remaining failures are the pre-existing int32-
overflow case, orthogonal).

Removes:
  * 16 instances/unified_attention_*_bs32_*.cpp files
  * unified_attention_decode_bs32_kernel_traits in unified_attention_impl.hpp
  * 3 _BS32 dispatch macros in unified_attention.cpp
  * 3 _p32 entries from the KernelVariant enum
  * 3 dispatch_*_p32 helper functions and their switch cases
  * the page_blk_size branch in select_config (now a pure tile-tier ladder)

Net: 12 fewer compile units (build time -6s on JIT), 78 fewer dispatcher
lines, and "which kernel runs?" is now driven purely by Q-tile shape.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 09:41:41 +00:00
juuso-oskari
fddb0d21cd Add d=128 MHA decode variant (decode_d128_mha_m128)
Until now every d=128 MHA workload took the 8-warp prefill kernel
(kBlockM=256, kBlockQ=256), wasting 255/256 Q rows on pure-decode
shapes where Q is 1. Add a dedicated 4-warp decode variant with
kBlockM=128 (kBlockQ=128) that cuts the Q-tile waste roughly in half.

  * Four new instance files at instances/unified_attention_d128_*_decode.cpp,
    each instantiating unified_attention_decode_kernel_traits<dt, mask, 128, 128, 1>.
  * KernelVariant::decode_d128_mha_m128 wired into select_config: chosen
    when both avg_q and max_seqlen_q fit in 128, else fall back to prefill.

Tests: ua-test-scripts/test_unified_attention_ck_correctness.py stays at
236/240 -- the pure-decode seq_lens pattern in head_config=(16,16,128)
now routes to the new variant and matches the torch reference. The 4
remaining failures are the pre-existing int32-overflow case (orthogonal).

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 09:34:52 +00:00
juuso-oskari
3ab4df37e2 Refactor CK-UA dispatcher around KernelVariant + select_config
The previous dispatcher was a 4-deep nested-if cascade that picked one
of seven DISPATCH_* macros based on (hdim, num_queries_per_kv, dtype,
mask, tile_tier, use_bs32). The macro names hid both the traits class
and the dispatch path, so reasoning about "what kernel runs for shape
X" required reading the whole file.

Replace it with two named layers:

  1. KernelVariant enum -- a flat list of every compiled instance.
  2. select_config(args) -- the only place runtime decisions live;
     reads the problem and emits a KernelConfig{variant, ...}.

The final switch over the variant calls into per-variant dispatch
helpers that fan out over (dtype, mask) via the existing DISPATCH_*
macros. Behaviour is unchanged: each old (hdim, nqpkv, tier, p32) tuple
maps 1:1 to a KernelVariant, and the same instance is launched.

Follow-up commits in this series will:
  - add a dedicated d=128 MHA decode variant
  - delete the _p32 ("bs32") family now that the multi-page-tile fix
    in the pipeline makes kBlockN independent of page_size

Test: ua-test-scripts/test_unified_attention_ck_correctness.py
      stays at 236/240 (same 4 pre-existing int32-overflow failures).
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 09:27:59 +00:00
juuso-oskari
25364aa634 Add KV-segment parallelism to CK unified attention pipeline
End-to-end split-KV (FlashDecoding-style) for the CK unified attention
kernel. The host launches a single 3D grid with z == num_splits; each
CTA computes its KV-range slice and writes a normalized (o_acc, lse)
partial to FP32 workspaces, which the caller reduces into the final
output.

Pipeline changes:
- operator() returns ck_tile::make_tuple(o_acc, lse) instead of just
  o_acc. The masked-empty early-exit returns lse = -inf so downstream
  combine weighs the partial as zero.
- LSE is built in the natural-log domain from the pipeline's *unscaled*
  rowmax: lse = (scale_s / log2(e)) * m + log(l). Previously we used
  m / log2(e) + log(l), which dropped the per-head scale and produced
  LSE values ~1/scale too large.
- Fix post-process parity: which SP register is left in the
  alu0-done/alu1-pending state at loop exit depends on the parity of
  the *iteration count* (= num_total_loop - num_blocks_start), not on
  num_total_loop alone. For non-split (num_blocks_start == 0) the two
  parities coincide; for splits starting at an odd tile they don't.
- Fix split-KV page-table offset: num_blocks_start is counted in
  kPageBlockSize-sized tiles, but block_tables is indexed in
  page_size-sized pages — shifting block_table_offset by num_blocks_start
  reads the wrong pages whenever kPageBlockSize != page_size. Replaced
  with split_token_offset = num_blocks_start * kPageBlockSize added to
  logical_token before /page_size, so the page lookup uses the absolute
  token position.

Kernel + dispatcher:
- Drop kargs.i_split; each CTA reads i_split = blockIdx.z.
- GridSize{2D,Decode} now take num_splits and add it as the z-dim
  (defaults to 1, so non-split callers see dim3(..., 1, 1)).
- New write path: when num_splits > 1, the kernel skips the user
  epilogue and instead writes the FP32 (o_acc, lse) tile pair into
  workspace tensors at [head, split, batch_start_token, ...] using
  Default2DEpilogue (UseRawStore=true) for o_acc and store_tile for
  lse. Host strides come from kargs.

Co-authored-by: Cursor <cursoragent@cursor.com>
2026-05-12 08:42:09 +00:00
root
63821af1ff Add split-KV decode tiles (b16x32, b32x32) + fix num_splits heuristic
Decode tiles for split-KV hdim=64: bm0=16/1-warp and bm0=32/2-warp.
Fix num_splits to use num_heads_kv (not num_heads_q) and target 4x SMs.

Performance unchanged (0.056ms) because:
1. Split+combine overhead dominates for short KV (31 pages)
2. Triton 3D's single-kernel split avoids combine kernel entirely

Made-with: Cursor
2026-04-01 18:49:16 +00:00
root
c5600bc8ae Add decode tiles (b16x32, b32x32) to pagedkv_prefill codegen with max_seqlen_q dispatch
Made-with: Cursor
2026-04-01 18:30:06 +00:00
root
65a3f88ad8 Fix CK-UA mixed batch: use max_seqlen_q for tier selection
Decode grid (num_kv_heads, num_seqs) assumes each seq has <= kBlockQ
tokens. For mixed batches (decode + prefill), avg_q is low but some
seqs have hundreds of tokens, causing truncation. Added max_seqlen_q
to args and check it in select_tile_tier to force medium tier (1D
grid with Q tile iteration) for mixed batches.

362/362 no-window shapes now pass.

Made-with: Cursor
2026-04-01 18:09:48 +00:00
root
07ba03bcbf Fix sliding window mask: use window_generic when left >= 0
mask_info::decode('b:left,right,sink') always created mask_bottom_right
(IsLocal=false) which ignores the left window boundary. For sliding
window attention (left >= 0), use window_generic (IsLocal=true) so the
kernel respects the left boundary.

Fixes: CK split-KV producing identical results with and without sliding
window. Now 724/724 shapes pass correctness vs Triton.

Made-with: Cursor
2026-04-01 18:00:19 +00:00
root
e5272603c9 Wire FmhaFwdPagedKV: enable bf16 hdim=64 with bn0=32 for page_block_size=32
Made-with: Cursor
2026-04-01 17:18:41 +00:00
root
10564b0c40 Enable FmhaFwdPagedKV bf16 hdim=64 instances (was commented out)
Made-with: Cursor
2026-04-01 16:49:20 +00:00
root
cd7ba6e2e8 Add unified attention (42_unified_attention)
Squashed from aghamari/unified-attention-decode-opt branch.

CK tile paged-KV attention kernel optimized for decode with 4-tier
dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid,
head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with
block_size=32.

Made-with: Cursor
2026-04-01 16:39:15 +00:00
root
cb6fb2802d Split-KV codegen: dual-tile dispatch and head-merge for hdim=64
1. Dual-tile: add both bn0=64 (preferred) and bn0=32 (fallback) for
   hdim=64 on gfx9 and gfx12. The dispatch checks page_block_size %
   bn0 == 0 at runtime to select the optimal tile. bn0=64 halves KV
   iterations when page_block_size >= 64.

2. Tile dict now supports lists per hdim. The codegen loop iterates
   over all tile variants, generating separate kernel instances for
   each. Combine kernels are unaffected (tile-independent).

3. Enable kMergeNumHeadGroupsSeqLenQ for hdim=64 decode (previously
   hdim=128 only). For GQA-8 with max_seqlen_q=1, this packs 8 head
   groups into the M dimension. Only activates for no-mask instances
   (kernel static_assert requires !kHasMask).

4. Add qr (non-async) pipeline for fwd non-bias group mode as
   fallback after qr_async. The async pipeline on this branch has a
   kernel-level bug where fmha_fwd launches but writes no output.

Made-with: Cursor
2026-04-01 16:24:25 +00:00
root
6729989b97 Fix FMHA split-KV for paged-KV with page_block_size < kN0
Cherry-picked from aghamari/unified-attention-decode-opt (fadf0d585).
- block_masking.hpp: 5-param GetTileRangeAlongX for GenericAttentionMask
- fmha_fwd_splitkv.py: bn0=32 for hdim=64

Made-with: Cursor
2026-04-01 16:24:19 +00:00
root
4c5e290378 Add unified attention (42_unified_attention) and topk_softmax_decode
Squashed from aghamari/unified-attention-decode-opt branch.

42_unified_attention: CK tile paged-KV attention kernel optimized for
decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA,
2D decode grid, head-group merging. Supports hdim=64 GQA-8 and
hdim=128 MHA with block_size=32.

topk_softmax_decode: fused topk + softmax kernel for M=1 MoE decode.

Made-with: Cursor
2026-04-01 16:24:04 +00:00
Fu-Cheng Tsai
a502e5a00b [rocm-libraries] ROCm/rocm-libraries#5798 (commit 7acd4e7)
[CK_TILE] Update gfx12 FMHA forward kernel configs
2026-04-01 14:23:38 +00:00
Hosang Yoon
2dcae9d173 [rocm-libraries] ROCm/rocm-libraries#5977 (commit 794bea7)
[CK_TILE] Fix Windows build in FMHA head grouping

## Motivation

This is a follow-up fix for [PR
#5018](https://github.com/ROCm/rocm-libraries/pull/5018).

[PR #5018](https://github.com/ROCm/rocm-libraries/pull/5018) added
LLC-aware FMHA head grouping / head-major scheduling on RDNA, but it
also introduced Linux-only code paths, including `<dirent.h>`, which
break Windows builds. This change fixes that by guarding the
Linux-specific LLC probing logic so non-Linux platforms can still build
correctly.

## Technical Details

- Guard `<dirent.h>` with `#ifdef __linux__`
- Guard KFD sysfs traversal logic with `#if defined(__linux__)`
- On non-Linux platforms, return `0` from
`get_kfd_sysfs_llc_cache_bytes()`
- Preserve existing fallback behavior through:
  - `CK_TILE_FMHA_LLC_CACHE_MB`
  - arch-based default LLC sizes
  - no head grouping when no LLC size can be resolved

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-30 14:19:19 +00:00
Jeff Huang
7968368d92 [rocm-libraries] ROCm/rocm-libraries#5918 (commit a7e2c67)
[CK][CK_TILE] Add fp8bf16 hdim=256 tile for batch prefill
 (#5918)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation
FP8 batch prefill kernels currently only support head_dim=128. Models
with head_dim=256 hit the "invalid argument for batch_prefill" error
because no matching kernel variant exists in the codegen dispatch.

## Technical Details
Add a hdim=256 tile size entry for fp8bf16 in the batch prefill codegen
recipe (`fmha_batch_prefill.py`).

Tile configuration: `FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4,1,1,
4,1,1, 32,32,32, 32,32,32, -1)`
- bm0=128, bn0=128 (Q/K tile sizes)
- bn1=256, bk0max=256 (V head_dim=256)
- Warp MFMA 32x32x32 (fp8 MFMA instructions)

This mirrors the existing bf16/fp16 hdim=256 tile but uses fp8 warp
sizes.

## Test Plan
Tested on both MI308X (gfx942) and MI355X (gfx950) via aiter batch
prefill test with the following matrix:
- page_size: {1, 16, 1024}
- kv_layout: {linear, vectorized}
- lookup_table: {sglang, vllm}
- causal: {true, false}
- logits_soft_cap: {0.0, 30.0}
- contiguous_kv: {true, false}

## Test Result

**MI308X (gfx942):** 160 passed, 32 skipped (page_size=1 + vectorized
not applicable)
**MI355X (gfx950):** 120 passed, 72 skipped (pre-existing ROCm 7.2
compiler issue with causal + no softcap)

No register spills on either platform.

### Profiling — MI355X (gfx950), FP8 pertensor, hdim=256, seqlen=1024, 8
heads

| page_sz | kv_layout | table | causal | soft_cap | time_us | TFLOPS |
|---------|-----------|-------|--------|----------|---------|--------|
| 1 | linear | sglang | False | 0.00 | 55.01 | 156.16 |
| 1 | linear | vllm | False | 0.00 | 55.12 | 155.84 |
| 1 | linear | sglang | False | 30.00 | 62.63 | 137.16 |
| 1 | linear | vllm | False | 30.00 | 62.16 | 138.20 |
| 1 | linear | sglang | True | 30.00 | 64.09 | 67.01 |
| 1 | linear | vllm | True | 30.00 | 63.85 | 67.27 |
| 16 | linear | sglang | False | 0.00 | 57.00 | 150.69 |
| 16 | vectorized | sglang | False | 0.00 | 57.55 | 149.25 |
| 16 | linear | vllm | False | 0.00 | 56.80 | 151.23 |
| 16 | vectorized | vllm | False | 0.00 | 57.32 | 149.87 |
| 16 | linear | sglang | False | 30.00 | 64.77 | 132.62 |
| 16 | vectorized | vllm | False | 30.00 | 63.54 | 135.18 |
| 16 | linear | sglang | True | 30.00 | 66.84 | 64.26 |
| 16 | vectorized | vllm | True | 30.00 | 66.12 | 64.96 |
| 1024 | linear | sglang | False | 0.00 | 58.25 | 147.46 |
| 1024 | vectorized | sglang | False | 0.00 | 57.53 | 149.31 |
| 1024 | linear | vllm | False | 0.00 | 58.06 | 147.94 |
| 1024 | vectorized | vllm | False | 0.00 | 57.55 | 149.27 |
| 1024 | linear | sglang | False | 30.00 | 65.38 | 131.38 |
| 1024 | vectorized | vllm | False | 30.00 | 63.64 | 134.98 |
| 1024 | linear | sglang | True | 30.00 | 66.85 | 64.25 |
| 1024 | vectorized | vllm | True | 30.00 | 65.26 | 65.81 |

### Profiling — MI308X (gfx942), FP8 pertensor, hdim=256, seqlen=1024, 8
heads

| page_sz | kv_layout | table | causal | soft_cap | time_us | TFLOPS |
|---------|-----------|-------|--------|----------|---------|--------|
| 1 | linear | sglang | False | 0.00 | 110.18 | 77.96 |
| 1 | linear | vllm | True | 30.00 | 134.33 | 31.97 |
| 1 | linear | sglang | True | 30.00 | 134.59 | 31.91 |
| 16 | linear | sglang | False | 0.00 | 115.43 | 74.42 |
| 16 | vectorized | sglang | False | 0.00 | 106.11 | 80.95 |
| 16 | linear | vllm | False | 0.00 | 116.34 | 73.83 |
| 16 | vectorized | vllm | False | 0.00 | 106.17 | 80.91 |
| 16 | linear | sglang | False | 30.00 | 135.61 | 63.34 |
| 16 | vectorized | vllm | False | 30.00 | 122.37 | 70.20 |
| 16 | linear | sglang | True | 0.00 | 117.44 | 36.57 |
| 16 | vectorized | vllm | True | 0.00 | 108.81 | 39.47 |
| 16 | linear | sglang | True | 30.00 | 139.43 | 30.80 |
| 16 | vectorized | vllm | True | 30.00 | 125.87 | 34.12 |
| 1024 | linear | sglang | False | 0.00 | 110.65 | 77.63 |
| 1024 | vectorized | sglang | False | 0.00 | 101.70 | 84.46 |
| 1024 | linear | vllm | False | 0.00 | 111.71 | 76.89 |
| 1024 | vectorized | vllm | False | 0.00 | 101.55 | 84.59 |
| 1024 | linear | sglang | False | 30.00 | 129.33 | 66.42 |
| 1024 | vectorized | vllm | False | 30.00 | 120.95 | 71.02 |
| 1024 | linear | sglang | True | 0.00 | 112.26 | 38.26 |
| 1024 | vectorized | vllm | True | 0.00 | 103.02 | 41.69 |
| 1024 | linear | sglang | True | 30.00 | 133.73 | 32.12 |
| 1024 | vectorized | vllm | True | 30.00 | 124.75 | 34.43 |

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-30 10:21:33 +00:00
Johannes Graner
58475d3f45 [rocm-libraries] ROCm/rocm-libraries#5393 (commit d51b649)
[CK Tile] StreamK support for Bwd Weight grouped convolutions
 (#5393)

## Motivation

Add StreamK work distribution to the CK Tile grouped convolution
backward weight kernel. Split-K divides the K-dimension uniformly across
a fixed `k_batch`, which causes load imbalance when the number of output
tiles doesn't evenly fill the GPU. StreamK distributes total
K-iterations evenly across workgroups, improving utilization on these
shapes.

## Technical Details

StreamK is added as an `if constexpr` branch in the existing kernel,
selected by the `TilePartitioner_` template parameter. Two reduction
strategies are supported:
- **Linear**: tile-starter sequentially accumulates partials from
contributing CTAs
- **Tree**: pairwise binary tree reduction (O(log n) depth, faster for
many contributors)

Both persistent and non-persistent data-parallel (DP) sections are
supported.

Key changes:
- `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution
path with `RunStreamK`/`RunStreamKLoop`, partial store/load via
workspace, flag-based cross-CTA synchronization,
`GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions
- `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers)
and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by
both GEMM and Conv StreamK kernels
- `streamk_gemm_kernel.hpp`: Refactored to use shared helpers
- Merged split-K and StreamK example invokers via `PartitionerPolicy`
template parameter
- StreamK example binary with `--streamk_reduction=linear|tree` and
`--streamk_persistent=0|1`
- CK Builder integration: `SpecifiesStreamK` concept,
`TilePartitionerType` factory helper, `InstanceTraits` with StreamK
fields
- 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP),
negative, builder regression

### Performance (MI355X, gfx950)

Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}):

| Shape | 16x64 tiles | | 128x128 tiles | |
|---|---|---|---|---|
| | Split-K | StreamK | Split-K | StreamK |
| 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x |
| 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x |
| 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x |
| 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x |
| 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x |
| 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x |
| 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x |
| 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x |

StreamK's value depends on tile config: with larger tiles (fewer output
tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up
to 1.16x on typical large-channel convolutions. Tree reduction
consistently outperforms Linear when multiple CTAs contribute to the
same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n)
sequential accumulation. The table reports the best of Linear and Tree
for each shape.

## Test Plan

```bash
ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk
./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk

# Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON)
ninja -C build check-builder
```

30 tests covering:
- Host-side: type traits, kernel args construction, grid size, workspace
size
- GPU end-to-end (Linear + Tree): small/medium shapes, multi-group,
stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher
occupancy
- Persistent DP: Linear + Tree with persistent data-parallel dispatch
- Negative: `IsSupportedArgument` rejects unaligned K and C
- Builder: Create (instance string validation) + Execution (reference
comparison) + instance string regression

## Test Result

All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK
tests pass. Full `check-builder` suite passes. Tolerances computed
dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware).

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-27 09:18:14 +00:00
yinglu
a268a2a2e1 [rocm-libraries] ROCm/rocm-libraries#5612 (commit 38c9498)
[CK]fix: remove redundant structured sparsity check in
 run_gemm_example.inc (#5612)

## Motivation

This issue if found via
https://github.com/ROCm/rocm-libraries/pull/4302#discussion_r2958603418
and is introduced via https://github.com/ROCm/rocm-libraries/pull/5323.

## Technical Details

The outer `if` and inner `if constexpr` both checked
GemmConfig::UseStructuredSparsity. Merged into a single `if constexpr`
since both preshuffle and UseStructuredSparsity are compile-time
constants.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-20 08:23:07 +00:00
yinglu
d460ab35b6 [rocm-libraries] ROCm/rocm-libraries#4302 (commit e62bd8a)
[CK_TILE] add tf32 support
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Proposed changes

TF32 is added in CK on gfx942 and gfx950. This PR is to initiate tf32 in
CK_TILE on gfx942 and gfx950.

## Checklist

Please put an into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [ ] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [ ] I have added inline documentation which enables the maintainers
with understanding the motivation
- [ ] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run  on all changed files
- [ ] Any dependent changes have been merged

## Discussion
2026-03-19 09:19:06 +00:00
Thomas Ning
5f90f69795 [rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement

## Motivation

Reduce the scale loading size and also has better utilization of MFMA
scale selection.

## Technical Details

Add up the packing of mx scales.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-17 18:58:56 +00:00
Hosang
859acb5ae7 [rocm-libraries] ROCm/rocm-libraries#5018 (commit b32e7e6)
[CK_TILE] Add LLC-aware FMHA head grouping and head-major
 scheduling on RDNA (#5018)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation
Long-sequence FMHA can become memory-bound when K/V working sets exceed
Infinity Cache (LLC), causing repeated DRAM traffic across heads.

This PR introduces LLC-aware launch ordering improvements for FMHA
forward, and it is currently enabled only on gfx11 and gfx12. The
approach is inspired by
[`Dao-AILab/flash-attention#2217`](https://github.com/Dao-AILab/flash-attention/pull/2217),
adapted to CK’s kernel/runner structure and layout handling.

In this context, `bshd` is the layout used in Flash-Attention, while
`bhsd` is the default layout used by the CK Tile FMHA example.

## Technical Details
This PR adds two complementary strategies:

- For `bshd` input layout (`i_perm/o_perm=0`), enable explicit LLC-aware
head grouping:
  - Estimate LLC size (env override, KFD sysfs, or arch default).
  - Compute group size from K/V bytes per head vs LLC target.
- Launch FMHA forward repeatedly per head-group by slicing Q/K/V/O (and
related tensors).

- For `bhsd` input layout (`i_perm/o_perm=1`), apply implicit
launch-order adjustment:
  - Keep a single kernel launch.
- Reinterpret block linearization in `GetTileIndex` to make execution
head-major,
     improving temporal locality of per-head K/V reuse.

Additional integration updates:
- Propagate `num_head_q_total` and `head_start` through FMHA args/kargs.
- Use global head indexing for dropout RNG stream mapping so grouped
launches keep
    deterministic/consistent dropout behavior.
- Keep fallback behavior unchanged when grouping is not beneficial or
disabled.

## Test Plan
- `test_ck_tile_fmha`
- `tile_example_fmha_fwd`

## Test Result
- `test_ck_tile_fmha`: all tests passed.
- `tile_example_fmha_fwd`: tested this on gfx1100, gfx1151, and gfx1201,
and all of them show higher performance compared to the baseline. The
improvement is consistent, and performance is well maintained even at
long sequence lengths.

./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=24 -d=128
-s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
- TFLOPs by sequence length target: gfx1100 layout: bhsd

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 56.27 | 61.48 | 1.09x
4096 | 67.10 | 72.27 | 1.08x
8192 | 65.99 | 71.64 | 1.09x
12288 | 61.60 | 76.61 | 1.24x
16384 | 58.99 | 75.74 | 1.28x
20480 | 57.32 | 74.42 | 1.30x
24576 | 56.89 | 74.25 | 1.31x
27280 | 18.93 | 24.48 | 1.29x

- TFLOPs by sequence length target: gfx1201 layout: bshd

SeqLen | Before | After | Speedup
-- | -- | -- | --
1024 | 66.79 | 65.90 | 0.99x
4096 | 85.90 | 86.80 | 1.01x
8192 | 77.06 | 90.29 | 1.17x
12288 | 58.36 | 88.98 | 1.52x
16384 | 52.12 | 88.88 | 1.71x
20480 | 48.11 | 88.42 | 1.84x
24576 | 47.12 | 89.07 | 1.89x
27280 | 49.05 | 50.31 | 1.03x

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-16 21:19:23 +00:00
Enrico Degregori
eb033ef208 [rocm-libraries] ROCm/rocm-libraries#4964 (commit 3271d9a)
[CK Tile] Eight Waves pipeline GEMM

## Motivation

Eight waves pipeline was added for ABQuant. The goal of this PR is to
enable it also for GEMM

## Technical Details

Summary:
 - Block:
- Create block struct for GEMM using eight warps specific distribution
encodings
   - Use this block struct in ABQuant for encodings
 - Pipeline:
- Create impl pipeline for eight waves which can be used by GEMM and
ABQuant as base (and for AQuant and BQuant in the future)
- Create eight waves pipeline for GEMM (this can not be easily
integrated in the existing async pipeline)
 - Pipeline policy:
- Extract GEMM specific parts in the ABQuant policy to define GEMM
policy (then ABQuant use it as base and add Quant specific methods)
- Minor: naming was inconsistent between warp/wave, everything is now
referred to as eight waves

So overall we have:
- block struct directly used by GEMM -> ABQuant derived struct to
implement operator
- Impl base pipeline with general implementation -> GEMM and ABQuant
pipelines use it to avoid code duplication but still define their own
pipelines
- pipeline policy struct directly used by GEMM -> ABQuant derived policy
struct for Quant specific parts

## Test Plan

Added new tests for GEMM pipeline:
`test_ck_tile_gemm_pipeline_comp_async_eight_waves` (only gfx950
supports it).

Note: K padding test is disabled for this pipeline because it's not
implemented yet

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-16 08:31:56 +00:00
Bartłomiej Kocot
b8108662da [rocm-libraries] ROCm/rocm-libraries#5387 (commit 0c259bd)
[CK][CK Tile] Grouped Convolution Backward Weight set of
 fixes (#5387)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Grouped Convolution Backward Weight split k fixes for CK tile kernels

## Technical Details

- get k batch from kargs to get deduced k batch
- multiply zeroing size by data type size
- disable v6 (producing a incorrect results)

## Test Plan

test_grouped_convnd_bwd_weight_tile

## Test Result

Pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-13 16:19:50 +00:00
Yi DING
574c1c121a [rocm-libraries] ROCm/rocm-libraries#5174 (commit a358a21)
[CK_TILE] FMHA BWD Use Persistent Kernels in Deterministic
 Mode (#5174)

## Motivation
This PR enables a persistent-kernel execution path for FMHA backward
(dQ/dK/dV) in deterministic mode, adjusting how dQ accumulation is
split, stored, and converted back to final gradients.

## Technical Details
- Introduces a persistent-kernel grid mapping in deterministic mode and
updates split-count calculation accordingly.
- Extends kernel kargs to carry batch-related info needed for persistent
scheduling and dQ conversion.
- Refactors dQ store conditions and adds mask-type traits/utilities and
runner logging updates.

## Test Plan
- Jenkins
[base](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/10/pipeline)
- Jenkins
[AITER](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/12/pipeline)
- Jenkins
[FMHA](http://micimaster.amd.com/blue/organizations/jenkins/rocm-libraries-folder%2FComposable%20Kernel/detail/PR-5174/11/pipeline)
- local FA tests

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-13 06:14:31 +00:00
Michał Kulikowski
2c3f9bfa52 [rocm-libraries] ROCm/rocm-libraries#5348 (commit 7b18234)
[CK][Examples] Adding parameters for a couple of CK examples:
 -gemm_add_add_mean_meansquare_xdl_fp16 -gemm_dl_quantization_int8
 -gemm_xdl_bias_relu_quantization_int8 -gemm_xdl_quantization_int8

Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
2026-03-12 08:48:36 +00:00
Aviral Goel
1a4aa7fd89 [rocm-libraries] ROCm/rocm-libraries#5082 (commit 9313659)
ck_tile: add gtest unit tests for MX flatmm (gfx950)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

- Add correctness unit tests for the MX-format flatmm kernel
(`example/ck_tile/18_flatmm/mxgemm`) under `test/ck_tile/flatmm/`
- Tests cover all five dtype combinations: FP4×FP4, FP8×FP8, FP6×FP6,
FP8×FP4, FP4×FP8
- Tests cover all four kernel dispatch paths (the `has_hot_loop` ×
`tail_num` product):
  - `has_hot_loop=false, tail=ODD` (K=256, num_loop=1)
  - `has_hot_loop=false, tail=EVEN` (K=512, num_loop=2)
  - `has_hot_loop=true, tail=ODD` (K=768, num_loop=3)
  - `has_hot_loop=true, tail=EVEN` (K=1024, num_loop=4)
- Remove unsupported `-split_k` CLI option from
`tile_example_mx_flatmm`; the pre-shuffled B layout is incompatible with
K-splitting and the option silently produced wrong results

## Changes

**New files (`test/ck_tile/flatmm/`):**
- `CMakeLists.txt` — builds 40 kernel instances as a shared OBJECT
library, links into 5 per-dtype test executables; forwards
`-DCK_TILE_USE_OCP_FP8` when `CK_USE_OCP_FP8` is ON
- `test_mx_flatmm_base.hpp` — base test fixture with
`run_test_with_validation(M, N, K, kbatch=1)`
- `test_mx_flatmm_fixtures.hpp` — concrete `TestMXFlatmm` typed test
class and type aliases
- `test_mx_flatmm_fp{4fp4,8fp8,6fp6,8fp4,4fp8}.cpp` — per-dtype
`TYPED_TEST_SUITE` files

**Modified files:**
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm_arch_traits.hpp` — moved
`preShuffleWeight` here (was in `mx_flatmm.cpp`) so it is includeable by
both the example and the tests
- `example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp` / `run_mx_flatmm.inc`
— removed `-split_k` CLI arg, hardcoded `k_batch=1`, fixed `k_split`
formula, updated call sites after `preShuffleWeight` move
- `test/ck_tile/CMakeLists.txt` — added `add_subdirectory(flatmm)`
2026-03-11 22:47:59 +00:00
Anton Gorenko
2312eef6c3 [rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on
 gfx950 (#4368)

## Motivation

Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline

## Technical Details

The microscaling is used when quant scale mode is
`BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are
fp8/bf8/fp4.

Supported features:
* only "qr" pipeline is implemented
* hdim 128 and 256 (smaller hdim are not possible due to restrictions of
"qr" pipeline, but they can be computed using instances with padding)
 * both 32x32x64 and 16x16x128 scale MFMAs are supported
 * Q and K scales are applied in hdim, V scales - in seqlen dimension
 * column-major V only
 * batch and group mode
 * bias, Alibi (tested but no instances by default, just like fp8)
 * masking etc.

Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008

## Test Plan

```
ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8
ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4
```

## Test Result

The tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-11 10:00:52 +00:00
Sami Remes
8f27f65d44 [rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-10 20:12:43 +00:00
Hosang
c800f88911 [rocm-libraries] ROCm/rocm-libraries#5088 (commit 36ca523)
[CK_TILE] Update gfx11 FMHA forward kernel configs

## Motivation
Tune gfx11 FMHA codegen to recover performance for mainly PSSK (padded
seqlen_q/k) cases.
This tuning is based on heuristic search and improves performance in
most tested shapes.
Performance should be evaluated on top of
[`ROCm/rocm-libraries#5018`](https://github.com/ROCm/rocm-libraries/pull/5018)
(required baseline).

## Technical Details

  - Updated gfx11 codegen heuristic choices for tile size and occupancy.
   - Updated gfx11 pipeline selection:
- Disabled the `npad` (`f,f,f,f`) qr entry because it was consistently
slower than the `pssk` (`t,t,f,f`) path, and kept `pssk` enabled so npad
cases are dispatched to the faster kernel path.`
- Kept gfx12 unchanged: with PSSK support from
[`ROCm/rocm-libraries#4957`](https://github.com/ROCm/rocm-libraries/pull/4957),
existing gfx12 config is already sufficient.
  - Tuning rationale:
    - In some cases, higher `kBlockPerCu` lowers register pressure.
- On RDNA, this generally aligns with better performance when
`waves_per_eu >= 6`.

## Test Plan
- test_ck_tile_fmha
- tile_example_fmha_fwd: tested this on gfx1100 and gfx1151
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=24
-d=128 -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}

## Test Result
- TFLOPs by sequence length target: `gfx1100` layout: `bhsd`
- mode: batch / VGPR usage: 225 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 74.10 | 71.97 | 0.97x
4096 | 66.26 | 77.79 | 1.17x
8192 | 68.18 | 75.88 | 1.11x
12288 | 68.47 | 80.44 | 1.17x
16384 | 59.54 | 79.66 | 1.34x
20480 | 55.78 | 77.91 | 1.40x
24576 | 55.08 | 77.47 | 1.41x
27280 | 47.45 | 77.16 | 1.63x
- mode: group / VGPR usage: 256 vs 214

SeqLen | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 71.47 | 70.6 | 0.99x
4096 | 64.74 | 77.06 | 1.19x
8192 | 64.68 | 75.47 | 1.17x
12288 | 66.43 | 79.95 | 1.20x
16384 | 56.02 | 79.73 | 1.42x
20480 | 50.21 | 78.15 | 1.56x
24576 | 47.29 | 77.53 | 1.64x
27280 | 46.13 | 77.04 | 1.67x

- TFLOPs by sequence length target: `gfx1151` layout: `bshd`
- mode: batch / VGPR usage: 225 vs 223

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 26.85 | 29.17 | 1.09x
4096 | 24.75 | 26.01 | 1.05x
8192 | 25.24 | 25.50 | 1.01x
12288 | 25.18 | 25.00 | 0.99x
16384 | 24.79 | 25.91 | 1.05x
20480 | 25.56 | 25.24 | 0.99x
24576 | 25.13 | 26.20 | 1.04x
27280 | 10.78 | 26.35 | 2.44x
- mode: group / VGPR usage: 256 vs 229

Batch | Baseline | Tuned | Gain
-- | -- | -- | --
1024 | 27.44 | 26.71 | 0.97x
4096 | 21.89 | 23.09 | 1.05x
8192 | 22.85 | 24.49 | 1.07x
12288 | 24.33 | 24.42 | 1.00x
16384 | 20.05 | 24.98 | 1.24x
20480 | 14.70 | 25.15 | 1.71x
24576 | 11.30 | 26.31 | 2.33x
27280 | 10.10 | 26.32 | 2.61x

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-10 16:47:43 +00:00
kensclin
8c216604d4 [rocm-libraries] ROCm/rocm-libraries#5218 (commit 60156cf)
[CK] Fix the issue of the aiter to call eightwarps pipeline.
 (#5218)

## Motivation

Fix the failure of the aiter to call eightwarp.
Changed Async to the name eightwarps.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

Pass

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-09 18:13:07 +00:00
Christopher Millette
e0d11b969b [rocm-libraries] ROCm/rocm-libraries#5030 (commit 8e02a26)
[CK] Replace tuple value construction with tuple_element_t
 type extraction [1A] (#5030)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Summary

### Rationale
CK's device operation instance registration uses
`add_device_operation_instances` at ~1,850
call sites to register GPU kernel configurations. The existing
implementation constructs
`std::tuple` values just to extract their types via `decltype`, then
copy-constructs each
instance into `make_unique`. This is wasteful — only the types matter,
not the values — and
forces the compiler to instantiate the full `std::tuple` constructor and
`std::get` machinery
at every call site.

### What changed
- Replace `remove_cvref_t<decltype(std::get<i>(tuple_obj))>` with
`std::tuple_element_t<i.value, TupleType>`, which extracts the type
directly without constructing any values
- Replace copy-from-default `make_unique<T>(value)` with direct default
construction `make_unique<T>()` — all CK device operation instances are
stateless structs with configuration encoded in template parameters
- Add `static_assert(std::is_default_constructible_v<NewOpInstance>)` to
enforce this contract at compile time with a clear error message
- Add Doxygen documentation for this high-traffic public API

### Value
- Eliminates unnecessary template instantiation of `std::tuple`
constructors and `std::get` across ~1,850 call sites
- Establishes a cleaner, more intention-revealing pattern for type-only
tuple usage
- The `static_assert` prevents silent breakage if a
non-default-constructible type is ever added
- No runtime behavior change — zero risk

### Files changed (9)
- `add_device_operation_instance.hpp`: Core pattern change
- 3 example files, 3 reduce instance headers, 1 convolution header, 1
profiler header

## Test plan
- [ ] Existing CI tests cover all ~1,850 call sites (GEMM, reduce,
softmax, convolution)
- [ ] `static_assert` provides compile-time validation stronger than
runtime tests
- [ ] No runtime behavior change — stateless struct default construction
is identical to copy-from-default
- [ ] Compatible with both `std::tuple` and `ck::type_list` containers

🤖 Generated with [Claude Code](https://claude.com/claude-code)
## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-06 16:28:22 +00:00
Ville Pietilä
ae4e632c7d [rocm-libraries] ROCm/rocm-libraries#4797 (commit 1a30400)
[CK_TILE] Add CK Tile bwd weight profiler
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

To compare old CK and CK Tile, we need to extend the current CK profiler
to support running also CK Tile instance with the same API. In order to
have the same instance coverage in CK Tile compared to the old CK, I've
added code generation from old CK configurations to CK Tile instances
using the CK Builder.

## Technical Details

- The codegen python script for CK Tile fwd convs is extended to support
also bwd weight and bwd data.
- The generated instances are added to the CMake build (target
`device_grouped_conv_bwd_weight_tile_instance`s).
- A new profiler op (`grouped_conv_bwd_weight_tile`) has been added to
the CK Profiler.
2026-03-04 21:50:29 +00:00
Anton Gorenko
782f0a9eed [rocm-libraries] ROCm/rocm-libraries#4957 (commit d2e4eb8)
[CK_TILE][FMHA] Extend pipelines with pssk for gfx11/12
 (#4957)

## Motivation

Build pipelines with seqlen padding only to support vectorized loads in
the hdim dimension.
The existing pipelines have either all dims padded or all dims not
padded.
These pipelines can be used in ComfyUI for slightly better performance.

## Technical Details

Also a fix included for correct FLOPS calculation in
`tile_example_fmha_fwd` when `seqlen_q * seqlen_k` overflows index_t
capacity (signed int32).

## Test Plan

The existing test cases will use the new pipelines when parameters allow
(seqlens - padded, hdims - not padded):
```
ninja test_ck_tile_fmha_fwd

bin/test_ck_tile_fmha_fwd_fp16
bin/test_ck_tile_fmha_fwd_bf16

bin/test_ck_tile_fmha_fwd_fp8bf16 # for gfx12
```

## Test Result

All tests must pass.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-04 04:50:44 +00:00
Yi DING
b09112bbad [rocm-libraries] ROCm/rocm-libraries#4577 (commit a36922c)
[CK_TILE] FMHA BWD Launcher Interface

## Motivation
Reduce memory usage; Be prepared to implement optimizations of reducing
nsplits in deterministic cases.

## Technical Details
This PR introduces a new launcher interface for the FMHA backward
operation, replacing direct function calls with a more structured
approach. The launcher encapsulates kernel dispatch logic and provides
access to computed metadata like the number of dQ acc splits.

**Changes:**
- Added `fmha_bwd_launcher` class that wraps kernel execution and
exposes `dq_acc_splits`
- Moved `fmha_bwd_traits` construction earlier in the execution flow to
support launcher initialization
- Refactored code generation to produce both legacy API and new launcher
constructor

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-04 01:21:07 +00:00
Brock Hargreaves
2a16d53cce [rocm-libraries] ROCm/rocm-libraries#5045 (commit 64a5502)
[CK] Address a bunch of errors associated with targeting
 gfx1200 on Windows (#5045)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Still addressing errors that are blocking the merge of TheRock PR:
https://github.com/ROCm/TheRock/actions/runs/22545831304/job/65308264096?pr=3382

## Technical Details

1. There are multiple fmha python scripts that are writing native paths
which are confusing cmake. I addressed one of these in an earlier PR
https://github.com/ROCm/rocm-libraries/pull/4812 and now I'm addressing
more that are exposed with gfx1200 target:

```
[composable_kernel configure] CMake Error at example/ck_tile/50_sparse_attn/CMakeLists.txt:61 (add_library):
[composable_kernel configure]   Syntax error in cmake code when parsing string
[composable_kernel configure]
[composable_kernel configure]     B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp
[composable_kernel configure]
[composable_kernel configure]   Invalid character escape '\b'.
```

2. In the following compiler error we see gemm_prec_str<ADataType,
BDataType> being passed as a function to concat(...), instead of being
evaluated with the parenthesis operator(), i.e.,
gemm_prec_str<ADataType, BDataType>(). There are multiples instances of
this, I wonder what non-msvc compilers do here:

```
[composable_kernel] FAILED: [code=1] example/ck_tile/38_block_scale_gemm/CMakeFiles/tile_example_gemm_quant.dir/gemm_bquant_quantgrouped_mx_bf16bf8.cpp.obj
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp:4:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/example/ck_tile/38_block_scale_gemm\run_gemm_quant_example.inc:17:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:7:
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/concat.hpp:119:21: error: implicit conversion between pointer-to-function and pointer-to-object is a Microsoft extension [-Werror,-Wmicrosoft-cast]
[composable_kernel]   119 |     ((oss << sep << rest), ...);
[composable_kernel]       |                     ^~~~
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp:248:16: note: in instantiation of function template specialization 'ck_tile::concat<char, char[11], std::basic_string<char> (), std::basic_string<char>>' requested here
[composable_kernel]   248 |         return concat('_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
[composable_kernel]       |                ^
```

There are plenty of other places where we use gemm_prec_str with the
operator(), so I'm pretty sure these were just typos...but I'd like some
eyes on it.

3. There are 2 tests that fail to build on Windows, which I've excluded
from the build but will open bug tickets for:
    1.  gemm_weight_preshuffle
    2.  grouped_gemm_preshuffle

Here's a sample of the compiler error for these tests:

```
[composable_kernel] [16/19] Building HIP object test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj
[composable_kernel] FAILED: [code=1] test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj
[composable_kernel] E:\TheRock\build\core\clr\dist\lib\llvm\bin\clang++.exe  -DCK_ENABLE_BF16 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_FP8 -DCK_ENABLE_INT8 -DCK_TILE_USE_WMMA=1 -DCK_TIME_KERNEL=1 -DCK_USE_OCP_FP8 -DCK_USE_WMMA -DCK_USE_WMMA_FP8 -DCK_USE_XDL -DDPP_KERNELS -DUSE_PROF_API=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -D__HIP_ROCclr__=1 -IE:/TheRock/rocm-libraries/projects/composablekernel/profiler/include -IE:/TheRock/rocm-libraries/projects/composablekernel -IE:/TheRock/rocm-libraries/projects/composablekernel/library/include -IE:/TheRock/rocm-libraries/projects/composablekernel/include -IE:/TheRock/build/ml-libs/composable_kernel/build/include -IE:/TheRock/build/base/half/stage/include -isystem E:/TheRock/build/core/clr/dist/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest/include -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/gtest-src/googletest -isystem E:/TheRock/build/ml-libs/composable_kernel/build/_deps/getopt-src/src -O3 -DNDEBUG -std=gnu++20 --offload-arch=gfx1200 -D_DLL -D_MT -Xclang --dependent-lib=msvcrt   -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Wno-missing-field-initializers -Wno-error=deprecated-declarations -Wall -Wextra -Wcomment -Wendif-labels -Wformat -Winit-self -Wreturn-type -Wsequence-point -Wswitch -Wtrigraphs -Wundef -Wuninitialized -Wunreachable-code -Wunused -Wno-reserved-identifier -Wno-option-ignored -Wsign-compare -Wno-extra-semi-stmt -Wno-unused-template -Weverything -Wno-c++98-compat -Wno-c++98-compat-pedantic -Wno-conversion -Wno-double-promotion -Wno-exit-time-destructors -Wno-extra-semi -Wno-float-conversion -Wno-gnu-anonymous-struct -Wno-gnu-zero-variadic-macro-arguments -Wno-missing-prototypes -Wno-nested-anon-types -Wno-padded -Wno-return-std-move-in-c++11 -Wno-shorten-64-to-32 -Wno-sign-conversion -Wno-unknown-warning-option -Wno-unused-command-line-argument -Wno-weak-vtables -Wno-covered-switch-default -Wno-unsafe-buffer-usage -Wno-unused-lambda-capture -Wno-nvcc-compat -Wno-c++20-compat -Wno-bit-int-extension -Wno-pass-failed -Wno-switch-default -Wno-unique-object-duplication -fbracket-depth=1024 -Wno-nrvo -Werror -Weverything -fcolor-diagnostics -Wno-c++20-extensions -Wno-global-constructors -Wno-undef -DCK_TILE_USE_OCP_FP8 -MD -MT test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -MF test\ck_tile\grouped_gemm_preshuffle\CMakeFiles\test_ck_tile_grouped_gemm_preshuffle.dir\test_grouped_gemm_preshuffle.cpp.obj.d -o test/ck_tile/grouped_gemm_preshuffle/CMakeFiles/test_ck_tile_grouped_gemm_preshuffle.dir/test_grouped_gemm_preshuffle.cpp.obj -x hip -c E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp:8:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host.hpp:6:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/host/check_err.hpp:16:
[composable_kernel] In file included from E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core.hpp:89:
[composable_kernel] E:/TheRock/rocm-libraries/projects/composablekernel/include\ck_tile/core/utility/env.hpp:110:31: warning: 'getenv' is deprecated: This function or variable may be unsafe. Consider using _dupenv_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details. [-Wdeprecated-declarations]
[composable_kernel]   110 |         const char* vp = std::getenv(name);
[composable_kernel]       |                               ^
[composable_kernel] C:\Program Files (x86)\Windows Kits\10\include\10.0.22621.0\ucrt\stdlib.h:1183:20: note: 'getenv' has been explicitly marked deprecated here
[composable_kernel]  1183 |     _Check_return_ _CRT_INSECURE_DEPRECATE(_dupenv_s)
[composable_kernel]       |                    ^
[composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:368:55: note: expanded from macro '_CRT_INSECURE_DEPRECATE'
[composable_kernel]   368 |         #define _CRT_INSECURE_DEPRECATE(_Replacement) _CRT_DEPRECATE_TEXT(    \
[composable_kernel]       |                                                       ^
[composable_kernel] C:\Program Files (x86)\Microsoft Visual Studio\2022\BuildTools\VC\Tools\MSVC\14.44.35207\include\vcruntime.h:358:47: note: expanded from macro '_CRT_DEPRECATE_TEXT'
[composable_kernel]   358 | #define _CRT_DEPRECATE_TEXT(_Text) __declspec(deprecated(_Text))
[composable_kernel]       |                                               ^
[composable_kernel] clang++: error: clang frontend command failed due to signal (use -v to see invocation)
[composable_kernel] AMD clang version 22.0.0git (https://github.com/ROCm/llvm-project.git a2dc42b87c63e686377a69f09ea23aec7550babc+PATCHED:e4d5bf498b7b8626bb9716f1f5a5946d45025918)
[composable_kernel] Target: x86_64-pc-windows-msvc
[composable_kernel] Thread model: posix
[composable_kernel] InstalledDir: E:\TheRock\build\core\clr\dist\lib\llvm\bin
[composable_kernel] clang++: note: diagnostic msg: Error generating preprocessed source(s).
[composable_kernel] ninja: build stopped: subcommand failed.
[composable_kernel FAILED WITH CODE 1 in 238 seconds]
ninja: build stopped: subcommand failed.
```

## Test Plan

Wait for internal CI and make sure build compiles locally.

## Test Result

Waiting on CI

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-03 21:55:14 +00:00
Kiefer van Teutem
b042e1805a [rocm-libraries] ROCm/rocm-libraries#4804 (commit 832dd0e)
Add Tile Distribution Encoding Register Mapping debug utility
 for MFMA / WMMA unification work. (#4804)

## Motivation

This PR adds a small utility that allows you to use Tile Distribution
Encodings to directly map matrix elements to register locations and vice
versa. It can also print forward and backward layout mappings similar to
the Matrix Calculator utility. The utility is not meant for index
calculations in actual kernels, but rather as a debugging tool and
probably for automated verification of the policy structs in the new
WMMA / MFMA unification design.

## Technical Details

Tile Distribution Encodings are a core part of CK Tile which can define
the relationship between register and intrinsic matrix fragment
elements. They allow for any mapping based on unmerge and merge
transformations. Also, they allow for a special "Repeat" dimensions
which acts like an additional matrix dimension and allows for
replication of certain matrix elements. The new mapping utility can deal
with all aspects.

## Test Plan

Since this is a debug utility there is nothing to directly test, but
there is an example file that defines four different Tile Distribution
Encodings and prints their forward and backward mappings, along with
some extra parameters.

## Test Result

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-02 16:42:42 +00:00
Linjun-AMD
78ae3835a6 [rocm-libraries] ROCm/rocm-libraries#4313 (commit 080ac66)
[CK] Fix gptoss sink

## Motivation

This PR removes conditional logic for handling infinity values in the
sink mechanism across multiple FMHA pipeline implementations, defaulting
sink_size to 0 and adding a constraint in the kernel selection logic.

## Technical Details

Changes:

Removed __builtin_isinf_sign(sink_v) checks and conditional
initialization of LSE accumulators across 7 pipeline files
Added default initialization (= 0) for sink_size in 4 argument structs
Added F_sink == "f" constraint to kernel compatibility checking

## Test Plan

Local test

## Test Result

passed

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-02 01:54:46 +00:00
Andriy Roshchenko
b661eab573 [rocm-libraries] ROCm/rocm-libraries#4821 (commit 9456e0f)
[CK TILE] Refactor MX FLATMM example

Refactor the MX FLATMM example to support more pipelines
across different architectures. This work facilitates the NPI team
roadmap.
2026-02-27 23:21:39 +00:00
Aviral Goel
c8a8449eec [rocm-libraries] ROCm/rocm-libraries#4816 (commit 17ff961)
[CK] Add split-K support for ABQuantGrouped in
 block_scale_gemm (#4816)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Changes

### Split-K support in `gemm_quant_kernel.hpp`

- **`SplitKBatchOffset`**: Added `aq_group_offset` and
`aq_k_split_offset` fields (mirroring the existing `bq_*` fields for B)
to track each split-K batch's position within the AQ scale tensor. For
`ABQuantGrouped`, both offsets are computed from `k_id * KRead` divided
by `AQuantGroupSize::kK`.

- **`MakeAQBlockWindow`**: Added an `aq_group_offset` parameter
(defaulting to 0 for non-split-K paths) so the AQ tensor view's K-group
dimension reflects only the remaining K-groups from the split-K offset,
consistent with how `MakeBQBlockWindow` handles the BQ tensor.

- **`RunGemm`**: Threads the `aq_k_split_offset` through to
`MakeAQBlockWindow` when in split-K mode.

### Constraints in `IsSupportedArgument()`

Four constraints gate split-K (`k_batch > 1`) for ABQuantGrouped:

1. **Mode check** — split-K is only allowed for `BQuantGrouped` (no
preshuffle) or `ABQuantGrouped` (no `APreshuffleQuant`). Any other quant
mode with `k_batch > 1` returns `false`.

2. **B quant group alignment** — `KRead` (per-batch K slice) must be
divisible by `BQuantGroupSize::kK`. Each batch must operate on complete
B quantization groups; a partial group would require splitting a scale
value across batches.

3. **A quant group alignment** (new, ABQuantGrouped only) — `KRead` must
also be divisible by `AQuantGroupSize::kK` for the same reason applied
to the AQ scale tensor.

4. **Minimum 2 K-tile iterations per batch** (new) — The
software-pipelined GEMM kernels (CompV3 family) prefetch one tile ahead,
so they require `per_batch_num_loop = KRead / KPerBlock >= 2`. When
`KRead == KPerBlock` (i.e. each batch is exactly one tile), the prefetch
reads into the next batch's memory region and produces incorrect
results. Configurations where `K == k_batch * KPerBlock` are therefore
rejected.

### Example update (`run_gemm_quant_example.inc`)

Updated the comment above the `IsSupportedArgument` call to document
that split-K is now supported for both `BQuantGrouped` (no preshuffle)
and `ABQuantGrouped` (no `APreshuffleQuant`).

## Unit Tests

Two new test files covering decode and prefill tile shapes across a
range of `k_batch` values (2–8), data types (FP8, BF8), and quantization
group sizes (1×1×128 and 1×128×128 for B):

- `test_gemm_quant_abquant_splitk_decode.cpp` — uses the decode tile
shape (M=16, N=64, K_tile=256)
- `test_gemm_quant_abquant_splitk_prefill.cpp` — uses the prefill tile
shape (M=128, N=128, K_tile=128)

Each test calls `run_test_with_validation` which runs the kernel and
checks correctness against a CPU reference. Configurations excluded from
tests are annotated with comments explaining which constraint they
violate (typically the `per_batch_num_loop >= 2` requirement).

## Prerequisites

This PR depends on #4429, which must be merged before this can be
merged.
2026-02-26 23:57:17 +00:00
Yung-sheng Tu
75aea70c2c [rocm-libraries] ROCm/rocm-libraries#4340 (commit 70a312f)
Implement device_grouped_gemm_fixed_nk_bias for RDNA4

## Proposed changes

Summary:

- Modified implementation for grouped_gemm_fixed_nk_bias
- FP16 WMMA examples
- WMMA instances
- Profiler for grouped_gemm_fixed_nk_bias
- Add WMMA instances to existing tests

**This PR depends on PR https://github.com/ROCm/rocm-libraries/pull/4299
and should be merged after it.
Only the last 6 commits are in the scope of this PR.**

## Checklist

Please put an `x` into the boxes that apply. You can also fill these out
after creating the PR. If you're not sure, please don't hesitate to ask.

- [x] I have added tests relevant to the introduced functionality, and
the unit tests are passing locally
- [x] I have added the test to REGRESSION_TESTS list defined at the top
of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more
than 30 seconds to run.
- [x] I have added inline documentation which enables the maintainers
with understanding the motivation
- [x] I have removed the stale documentation which is no longer relevant
after this pull request
- [ ] (If this change is user-facing) I have added release notes which
provide the end users with a brief summary of the improvement from this
pull request
- [x] I have run `clang-format` on all changed files
- [ ] Any dependent changes have been merged

## Discussion

If this is a relatively large or complex change, feel free to start a
discussion by explaining why you chose the solution you did and what
alternatives you considered

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-02-26 00:28:58 +00:00
Bartłomiej Kocot
eede24de0d [rocm-libraries] ROCm/rocm-libraries#4872 (commit ca623f7)
[CK] Small improvements for grouped conv backward weight
 (#4872)

## Motivation

Improvements for CK Tile convolution builder run function and atol/rtol
calculations.

## Technical Details

- Add preprocessing function for wrw when k_batch is larger than 1 for
builder run function
- Divide num acums by number of groups to get real number of accums

## Test Plan

CI wrw tests

## Test Result

pending

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-783
2026-02-25 20:11:01 +00:00
Brock Hargreaves
c90a363589 [rocm-libraries] ROCm/rocm-libraries#4812 (commit bb5a4dd)
[CK] Use as_posix() instead of str() for paths in
 fmha_fwd_appendkv.py (#4812)

## Motivation

This is causing a failing PR for Windows:
https://github.com/ROCm/TheRock/pull/3382
```

[composable_kernel configure] -- Jenga kernel files to be generated: B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_mask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_mask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_bf16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_bf16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_bf16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_mask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_bf16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_mask_nskip_nsquant_ntrload.cpp;B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_api.cpp
[composable_kernel configure] CMake Error at example/ck_tile/50_sparse_attn/CMakeLists.txt:61 (add_library):
[composable_kernel configure]   Syntax error in cmake code when parsing string
[composable_kernel configure]
[composable_kernel configure]     B:\build\ml-libs\composable_kernel\build\example\ck_tile\50_sparse_attn\fmha_jenga_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psddv_nlogits_nbias_nmask_nskip_nsquant_ntrload.cpp
[composable_kernel configure]
[composable_kernel configure]   Invalid character escape '\b'.
```

## Technical Details

The file:
[fmha_fwd_appendkv.py](https://github.com/ROCm/rocm-libraries/compare/users/brockhargreaves-amd/ck/fix-windows-cmake-path-problem?expand=1#diff-bef22bf9ba21eb93c725493ecc7edcb6f2a8f0a9a173dcfca6bda7a9f4eced78)
writes a bunch of paths to a text file which is later parsed by cmake.
When passing a pathlib.Path to str(), str() converts to a native path,
in this case / to \\ on Windows which confuses cmake. In this case we
need to write paths with forward slashes and then pass those onward to
cmake.

## Test Plan

1. Ensure this doesn't impact existing CI.
2. Ensure compilation of Windows pass locally.

## Test Result

1. Passes existing CI
2. This fixes the compilation error locally.

## Submission Checklist

- [ x ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-02-25 16:13:56 +00:00
Brock Hargreaves
abf13bdec1 [rocm-libraries] ROCm/rocm-libraries#4819 (commit b995a0b)
[CK] Fix windows build issues

## Motivation

Full build on Windows is currently broken due to compiler errors, this
PR should help fix that. This is also holding up the following PR in the
TheRock: https://github.com/ROCm/TheRock/pull/3382

## Technical Details

1. I don't see a good reason to be nesting a windows include inside the
ck_tile namespace. It was causing compiler errors too: Windows.h comes
with min and max, which was conflicting with ck_tile::min and
ck_tile::max, so I moved it out. I also defined NOMINMAX to prevent this
inclusion in the future.
2. The TRUE/FALSE macros are already used by Windows.h, which causes an
error. So I've opted for True/False. You can see this pattern in other
rocm-libraries.
3. The M_PI macro isn't available, at least in the WIN32_LEAN_AND_MEAN
context, from \<cmath\> on Windows. We'll be able to use
std::numbers::v_pi\<float\> when we have C++20 support.
4. There was a missing \<chrono\> include.

## Test Plan

Test locally and make sure this doesn't impact existing CI.

## Test Result

Compiles locally and passes existing ci.

## Submission Checklist

- [ x ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-02-25 16:13:13 +00:00