13 Commits

Author SHA1 Message Date
Sami Remes
a3a12b8945 [rocm-libraries] ROCm/rocm-libraries#5813 (commit 18b43cf)
[CK_TILE] Enable full transpose layout support for MX GEMM
 pipeline (#5813)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Enable full transpose layout support for MX GEMM pipeline (32x32x64
MFMA)

### Summary

This PR enables all four matrix layout combinations (Row/Col, Row/Row,
Col/Col, Col/Row) for the MX GEMM pipeline with `32x32x64` MFMA warp
tiles, using `ds_read_tr` transposed LDS loads on gfx950. Previously,
only the canonical `A=RowMajor, B=ColumnMajor` layout was supported.

### Changes

**Kernel-side transpose support:**

- **`warp_gemm_attribute_mfma.hpp`**: Introduce `kSplitFactor` logic in
`get_warp_dstr_encoding` to split the K-dimension distribution encoding
when `kPerLane` exceeds the `ds_read_tr` subtile minor dimension. This
satisfies the `TransposeTileDistributionTraits` suffix validation
required by `load_tile_transpose`. The distribution encoding now also
receives the `DataType` template parameter to compute the split factor
based on packed element size.

- **`gemm_pipeline_ag_bg_cr_comp_async.hpp`**: Uncomment and enable the
`InputTileDistributionTraits` logic to properly transform LDS load tile
distributions for transposed reads. Add `static_assert`s to catch
misconfigurations where a layout requires transpose loads but the warp
tile size disables them (e.g. `KWarpTile=128` exceeds `ds_read_tr`
limits).

- **`load_tile_transpose.hpp`**: Fix `DataVec` sizing for packed types
(`pk_fp4_t`) — divide `vecLoadSize` by `PackedSize` to prevent buffer
overflow when each physical element contains multiple logical values.

- **`warp_gemm_attribute_mfma_impl.hpp`**: Set `kDefaultScale` to
`0x7F7F7F7F` (unity in e8m0 format) for the unscaled `operator()`
overloads of `WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4`, ensuring
correct behavior with `mfma_scale_f32_32x32x64_f8f6f4`.

- **`warp_gemm.hpp` / `warp_gemm_dispatcher.hpp`**: Add generic
`WarpGemmMfma_f32_32x32x64_f8f6f4<A, B>` alias and dispatcher
specialization to support arbitrary MX data type combinations (fp4, fp6,
fp8) with the 32x32x64 MFMA, consolidating the existing type-specific
aliases.

- **`gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp`**: Simplify
`wg_attr_num_access` determination — `Double` for fp8, `Single`
otherwise.

**Reference implementation fix:**

- **`reference_gemm.hpp`**: Fix nibble selection for packed 4-bit types
(`pk_fp4_t`, `pk_int4_t`) in `reference_mx_gemm`, `reference_gemm`, and
`reference_gemm_abquant`. The previous logic used `k % 2` or
`index[K_DIM] & 1` to select which nibble to extract, which assumed K
was always the fast (contiguous) memory dimension. This is only true for
`A=RowMajor` / `B=ColumnMajor`. For other layouts, the fix computes the
flat memory offset via `mDesc.GetOffsetFromMultiIndex(...)` and uses its
parity to correctly select the nibble regardless of layout.

**Test infrastructure:**

- **`test_mx_gemm_config.hpp`**: Add `MxGemmConfig32` base and
`MXfp4_GemmConfig32` / `MXfp8_GemmConfig32` configs for the 32x32x64
warp tile.
- **`test_mx_gemm_fp4.cpp` / `test_mx_gemm_fp8.cpp`**: Add `Config32`
test suites covering all four layout combinations. Restrict `Config16`
(16x16x128) to `A=Row, B=Col` only, since `KWarpTile=128` exceeds
`ds_read_tr` limits.
- **`test_mx_gemm_util.hpp`**: Fix scale tensor layout — scales are
always row-major `[M, K/32]` and column-major `[K/32, N]`, independent
of A/B data layout.

### Test plan

- [x] `test_ck_tile_mx_gemm_fp4` — 5/5 passed (16x16x128 Row/Col +
32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp8` — 5/5 passed (16x16x128 Row/Col +
32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp6` — 1/1 passed (16x16x128 Row/Col)
2026-06-18 17:05:09 +00:00
Sami Remes
c1f7104852 [rocm-libraries] ROCm/rocm-libraries#6663 (commit f19fc01)
[CKTile] Fix MX GEMM: num_loop==3 dispatch, split-K,
 unsupported-shape guard (#6663)

Three independent MX GEMM correctness bugs reported against
example/ck_tile/42_mx_gemm (fp8xfp8, A=Row/B=Col) on MI350X, plus one
host-side atomic-add accumulation bug in the example's repeat loop.

- Pipeline (gemm_pipeline_ag_bg_cr_comp_async.hpp): BlockHasHotloop
required num_loop > PrefetchStages, which let num_loop == 3 enter a hot
loop that produced 5 gemm accumulations instead of 3 (K == 3*K_Tile,
e.g. K=768, deterministically wrong). Require num_loop >= 4 instead:
pre-pipeline + TailNumber::Three already totals exactly 3.

- Kernel (gemm_mx_kernel.hpp): split-K was silently broken because
GridSize did not thread k_batch into blockIdx.z and the scale tile
windows were anchored at K=0 for every k_id. Every k_id >= 1 therefore
read the wrong packed scales. Fix:
* GridSize returns dim3(grid_x, 1, k_batch) (persistent and
non-persistent).
* MakeScaleA/BBlockWindows accept a k_elem_offset and translate it to a
packed-scale K offset (also apply pad_tensor_view so OOB scale loads
return zero, matching A/B padding).
* operator() derives k_id from blockIdx.z, uses GetSplitKElemOffset
(matches Underlying::SplitKBatchOffset's K1-aligned formula), and
dispatches the epilogue with memory_operation_enum::atomic_add for
k_batch > 1, set for k_batch == 1. Same fp16/bf16 even-vector-size guard
as UniversalGemmKernel.
* MakeCBlockWindows templated on DstInMemOp; unconditionally applies
pad_tensor_view using kPadM/kPadN so partial trailing M/N tiles are
handled correctly.

- Compile- and runtime unsupported-shape guards (gemm_mx_kernel.hpp):
add IsSupportedArgument and a static_assert for configurations that
produce silent wrong results:
* static_assert(!kPadK) -- the MX comp-async pipeline uses
async_load_tile whose OOB check is per-vector-start, so a vector
straddling the K pad boundary reads garbage. Until the async path learns
per-element pad masking, reject kPadK at compile time.
* Runtime: k_batch >= 1; M/N multiples of MPerBlock/NPerBlock when
kPadM/kPadN are false; M >= MPerBlock and N >= NPerBlock always
(CShuffleEpilogue cannot safely run with a single partial tile); K %
(KPerBlock * k_batch) == 0; and for k_batch > 1, K must be a multiple of
WarpTile_K * k_batch so every split lands on a packed-scale boundary.
  * All error paths log under CK_TILE_LOGGING with actionable messages.

- Example (example/ck_tile/42_mx_gemm/mx_gemm_instance.hpp):
* Call Kernel::IsSupportedArgument up front and throw a clear
runtime_error for rejected shapes (was silently launching an unsupported
kernel).
* Switch to launch_kernel_time_mask with a clear_gemm_output preprocess
that zeroes C between iterations when k_batch > 1 (mirrors
universal_gemm_invoker). Without this the default -warmup=50 -repeat=100
accumulated 150 atomic_adds into C after the kernel-side split-K fix.

Tests (test/ck_tile/gemm_mx/):
- Add MXfp8_GemmConfig16_PadMN (kPadM = kPadN = true).
- test_mx_gemm_fp8.cpp: HotLoopTailNumLoopThree (K=768 regression),
SplitK (k_batch=2,4 across full_k/partial_k paths),
TestMxGemmFp8PadMN::{MNPaddingAligned, MPadding, NPadding, MNPadding}
covering trailing partial tiles along M, N, or both.
- Run(...) now takes k_batch.
- packScalesMNxK: guard against OOB (mn, k) reads from src and
initialise e8m0 bytes to the zero exponent (0x00) instead of the
default-constructed NaN (0xFF), so padded lanes don't poison the packed
int32_t shared with in-range lanes.
- test_mx_gemm_instance.hpp: call IsSupportedArgument before launch.

Verification on gfx950, ROCm 7.2.0:
- ctest -R test_ck_tile_mx_gemm -> 100% (2/2).
- Example sweep over the original bug-report shapes: all K-aligned
shapes now validate correct (including 4096^3 sk=2 and the K=768 cases);
all K=128 shapes cleanly rejected with the new error message instead of
producing silent wrong results.

Made-with: Cursor

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-15 08:28:55 +00:00
Enrico Degregori
1b4fbd95fd [rocm-libraries] ROCm/rocm-libraries#6089 (commit c876d18)
[CK Tile] Extend type support EightWave pipeline

## Motivation

EightWave pipeline was designed for 8 bit types. This PR extend support
for any FP type

## Technical Details

 - Generalize policy to support any FP type
- Change LDS layout to fix bank conflicts. This removes all bank
conflicts in the pipeline (checked for all supported types). Remaining
bank conflicts are related to Cshuffle epilogue.

## Test Plan

Added GEMM tests with new supported types. Note that FP6 is also
supported for MX GEMM but the PR was reverted so no tests were added for
it.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-05 23:54:40 +00:00
Illia Silin
c24e528481 [rocm-libraries] ROCm/rocm-libraries#7760 (commit a61bc76)
[CK] suppress compiler warnings while building pytorch. (#7760)

## Motivation

Recently added compiler flags that are required to suppress false
warnings by latest staging compiler are not recognized by older compiler
versions and are triggering an avalanche of warnings. Previous attempt
to suppress them by using -Wno-unknown-warning-option flag didn't help,
because that flag wasn't recognized either and just added more warnings.
I've verified that current approach by checking the clang version
actually works as intended and makes the warnings go away.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-27 06:56:58 -07:00
JP-Fernando
74bc86240b [rocm-libraries] ROCm/rocm-libraries#5647 (commit 490437a)
[CK Tile] Add gemm universal preshuffle to MX GEMM  (#5647)

## Motivation

Add gemm universal preshuffle support to existing MX GEMM pipeline.

The straightforward way to do this is to port the `mx_flatmm` pipeline
to the existing `gemm_mx` framework.

## Technical Details

The `mx_flatmm` pipeline was not deleted, to allow for
back-compatibility.

## Test Plan

Add `preshuffle` option to example: `tile_example_mx_gemm`.

Add new configurations with enabled preshuffle to the existing
`test/ck_tile/gemm_mx` tests.

## Test Result

Example and tests were successful on `gf950` architecture in the `Alola`
cluster.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
2026-05-22 16:07:53 +02:00
JP-Fernando
e7798e9560 [rocm-libraries] ROCm/rocm-libraries#7112 (commit a6e5eac)
Add asynchronous XOR shuffle support to the Async GEMM pipeline and the MX GEMM pipeline (#7112)

## Motivation

The goal of this work is to apply XOR shuffle (swizzle) to the current
`comp_async` GEMM pipeline and the `gemm_mx` pipeline.
XOR swizzling has been helpful to avoid LDS bank conflicts, as data are
redistributed across LDS banks, such that simultaneous threads accessing
different rows land on different LDS banks.

## Technical Details

A similar approach to the work in the existing eight-waves pipeline was
followed.
Currently, XOR swizzle support is available for FP8 and BF8 types.
FP4 support is also available for MX GEMM.
Should the types not match, or should the async vector width be of an
unsupported size, then the pipeline falls through to the previously
existing ('unswizzled') path.

## Test Plan

Execute `test_ck_tile_gemm_pipeline_comp_async` for the Async GEMM
pipeline.
Execute `test_ck_tile_mx_gemm_fp8` and `test_ck_tile_mx_gemm_fp4` for
the MX GEMM pipeline.

## Test Result

The tests passed successfully in the `Alola` cluster with MI350
hardware.

## Submission Checklist

- [X] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Fernando Jiménez <fernando.jimenez@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
2026-05-21 09:36:41 +02:00
Enrico Degregori
9565ca21ec [rocm-libraries] ROCm/rocm-libraries#5552 (commit 369c7a2)
[CK Tile] Eight Waves pipeline for MX GEMM (#5552)

## Motivation

Integrate Eight Waves pipeline in MX GEMM

## Technical Details

 - EightWaves pipeline:
- Add pipeline, policy and block gemm (internally using existing
implementation used by GEMM and ABQuant)
   - Extend support of EightWaves policy for FP4 (packed types)
 - Async pipeline:
- Fix pipeline with packed scales (requires MRepeat and NRepeat to be
contiguous)
- block gemm specific for MX GEMM is defined because distribution
encodings have changed
 - CShuffle:
- Add new functionality to support MRepeat and NRepeat contiguous
(defined by `TilesPacked`)
 - Examples:
- Refactor examples to easily switch different configurations (similar
to GEMM universal)
- Scales values generated consistently with other microscale
implementations in CK Tile
   - Add configuration for EightWaves pipeline
 - Tests:
   - Unify existing FP8 and FP4 tests
   - Add tests for EightWaves pipeline
- Scales values generated consistently with other microscale
implementations in CK Tile

Note: FP6 support for MX GEMM was added later and the support for the
Eight Waves pipeline will be done in following PR

## Test Plan

Add new pipeline to tests: `test_ck_tile_mx_gemm_async` for both FP4 and
FP8

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-19 11:53:19 -07:00
Illia Silin
ac18460782 [rocm-libraries] ROCm/rocm-libraries#7384 (commit 10e9d70)
[CK] Suppress new staging compiler errors (#7384)

## Motivation

This should make new builds with staging compiler pass.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-14 12:51:08 -07:00
Illia Silin
d16061f578 [rocm-libraries] ROCm/rocm-libraries#6550 (commit c396de9)
[CK] Fix/suppress clang lifetimebound warnings with staging compiler. (#6550)

## Motivation

New changes from upstream llvm-project cause an avalanche of warnings in
CK. Gonna disable them by ignoring the
lifetime-safety-intra-tu-suggestions flag until a better permanent
solution is found.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-22 15:47:47 +00:00
Linjun-AMD
7469320248 [rocm-libraries] ROCm/rocm-libraries#5849 (commit d9b89b2)
[CK_TILE ]Revert "[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)" (#5849)

This reverts commit 7e55766ddf7e9e20791b0e4e2d7b4026cf16b637.

## Motivation

<!-- Explain the purpose of this PR and the goals it aims to achieve.
-->

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-27 20:36:39 +00:00
Sami Remes
8f0ede3ea2 [rocm-libraries] ROCm/rocm-libraries#5095 (commit 7e55766)
[CK_TILE] Enable MXFP6 for MX GEMM op (#5095)

## Motivation

Add support for MXFP6 in the MX GEMM op in CK-Tile.

Depends on https://github.com/ROCm/rocm-libraries/pull/4594

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-03-19 18:07:47 -07:00
Thomas Ning
1ab29bf22f [rocm-libraries] ROCm/rocm-libraries#5323 (commit 5454e9e)
CK Tile MX GEMM Packing Improvement (#5323)

## Motivation

Reduce the scale loading size and also has better utilization of MFMA
scale selection.

## Technical Details

Add up the packing of mx scales.

## Test Plan

Use the existing test cases.

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Sami Remes <samremes@amd.com>
Co-authored-by: Enrico Degregori <enrico@streamhpc.com>
2026-03-17 11:57:32 -07:00
Sami Remes
c1525b3f30 [rocm-libraries] ROCm/rocm-libraries#4594 (commit 1fce4cb)
[CK_TILE] MX GEMM non-preshuffled RCR layout (#4594)

## Motivation

Implements a GEMM with MX scaling for fp4 and fp8 in non-preshuffled
layouts using async pipeline.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: ThomasNing <thomas.ning@amd.com>
2026-03-10 20:12:05 +00:00