[CK_TILE] Enable full transpose layout support for MX GEMM pipeline (#5813)

## Enable full transpose layout support for MX GEMM pipeline (32x32x64 MFMA)

### Summary

This PR enables all four matrix layout combinations (Row/Col, Row/Row, Col/Col, Col/Row) for the MX GEMM pipeline with `32x32x64` MFMA warp tiles, using `ds_read_tr` transposed LDS loads on gfx950. Previously, only the canonical `A=RowMajor, B=ColumnMajor` layout was supported.

### Changes

**Kernel-side transpose support:**

- **`warp_gemm_attribute_mfma.hpp`**: Introduce `kSplitFactor` logic in `get_warp_dstr_encoding` to split the K-dimension distribution encoding when `kPerLane` exceeds the `ds_read_tr` subtile minor dimension. This satisfies the `TransposeTileDistributionTraits` suffix validation required by `load_tile_transpose`. The distribution encoding now also receives the `DataType` template parameter to compute the split factor based on packed element size.
- **`gemm_pipeline_ag_bg_cr_comp_async.hpp`**: Uncomment and enable the `InputTileDistributionTraits` logic to properly transform LDS load tile distributions for transposed reads. Add `static_assert`s to catch misconfigurations where a layout requires transpose loads but the warp tile size disables them (e.g. `KWarpTile=128` exceeds `ds_read_tr` limits).
- **`load_tile_transpose.hpp`**: Fix `DataVec` sizing for packed types (`pk_fp4_t`): divide `vecLoadSize` by `PackedSize` to prevent buffer overflow when each physical element contains multiple logical values.
- **`warp_gemm_attribute_mfma_impl.hpp`**: Set `kDefaultScale` to `0x7F7F7F7F` (unity in e8m0 format) for the unscaled `operator()` overloads of `WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4`, ensuring correct behavior with `mfma_scale_f32_32x32x64_f8f6f4`.
- **`warp_gemm.hpp` / `warp_gemm_dispatcher.hpp`**: Add generic `WarpGemmMfma_f32_32x32x64_f8f6f4<A, B>` alias and dispatcher specialization to support arbitrary MX data type combinations (fp4, fp6, fp8) with the 32x32x64 MFMA, consolidating the existing type-specific aliases.
- **`gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp`**: Simplify `wg_attr_num_access` determination: `Double` for fp8, `Single` otherwise.

**Reference implementation fix:**

- **`reference_gemm.hpp`**: Fix nibble selection for packed 4-bit types (`pk_fp4_t`, `pk_int4_t`) in `reference_mx_gemm`, `reference_gemm`, and `reference_gemm_abquant`. The previous logic used `k % 2` or `index[K_DIM] & 1` to select which nibble to extract, which assumed K was always the fast (contiguous) memory dimension. This is only true for `A=RowMajor` / `B=ColumnMajor`. For other layouts, the fix computes the flat memory offset via `mDesc.GetOffsetFromMultiIndex(...)` and uses its parity to correctly select the nibble regardless of layout.

**Test infrastructure:**

- **`test_mx_gemm_config.hpp`**: Add `MxGemmConfig32` base and `MXfp4_GemmConfig32` / `MXfp8_GemmConfig32` configs for the 32x32x64 warp tile.
- **`test_mx_gemm_fp4.cpp` / `test_mx_gemm_fp8.cpp`**: Add `Config32` test suites covering all four layout combinations. Restrict `Config16` (16x16x128) to `A=Row, B=Col` only, since `KWarpTile=128` exceeds `ds_read_tr` limits.
- **`test_mx_gemm_util.hpp`**: Fix scale tensor layout: scales are always row-major `[M, K/32]` and column-major `[K/32, N]`, independent of A/B data layout.

### Test plan

- [x] `test_ck_tile_mx_gemm_fp4`: 5/5 passed (16x16x128 Row/Col + 32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp8`: 5/5 passed (16x16x128 Row/Col + 32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp6`: 1/1 passed (16x16x128 Row/Col)
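The nibble-selection fix in `reference_gemm.hpp` can be illustrated with a small standalone sketch. It is not the CK implementation: `Desc2D`, `store_nibble`, `load_nibble_by_offset`, and `load_nibble_by_k` are hypothetical names, and the convention that an even flat offset maps to the low nibble is an assumption made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor descriptor. Strides are counted in
// logical 4-bit elements, not bytes.
struct Desc2D
{
    std::size_t stride_m;
    std::size_t stride_k;

    std::size_t offset(std::size_t m, std::size_t k) const
    {
        return m * stride_m + k * stride_k;
    }
};

// Assumption for this sketch: the element at an even flat offset occupies the
// low nibble of its byte, the element at an odd flat offset the high nibble.
inline void store_nibble(std::vector<std::uint8_t>& buf, std::size_t flat, std::uint8_t v)
{
    assert(flat / 2 < buf.size());
    std::uint8_t& byte = buf[flat / 2];
    byte = (flat & 1) ? static_cast<std::uint8_t>((byte & 0x0F) | ((v & 0x0F) << 4))
                      : static_cast<std::uint8_t>((byte & 0xF0) | (v & 0x0F));
}

// Layout-agnostic read: the nibble is chosen by the parity of the flat offset.
inline std::uint8_t load_nibble_by_offset(const std::vector<std::uint8_t>& buf,
                                          const Desc2D& d,
                                          std::size_t m,
                                          std::size_t k)
{
    const std::size_t flat = d.offset(m, k);
    assert(flat / 2 < buf.size());
    const std::uint8_t byte = buf[flat / 2];
    return (flat & 1) ? static_cast<std::uint8_t>(byte >> 4)
                      : static_cast<std::uint8_t>(byte & 0x0F);
}

// Old-style read: the nibble is chosen by the parity of k, which is only
// correct when K is the contiguous dimension.
inline std::uint8_t load_nibble_by_k(const std::vector<std::uint8_t>& buf,
                                     const Desc2D& d,
                                     std::size_t m,
                                     std::size_t k)
{
    const std::size_t flat = d.offset(m, k);
    assert(flat / 2 < buf.size());
    const std::uint8_t byte = buf[flat / 2];
    return (k & 1) ? static_cast<std::uint8_t>(byte >> 4)
                   : static_cast<std::uint8_t>(byte & 0x0F);
}
```

For an `M x K` matrix with `Desc2D{K, 1}` (K contiguous) both readers agree. With `Desc2D{1, M}` (M contiguous) the parity of `k` no longer matches the parity of the flat offset, so only `load_nibble_by_offset` returns the stored value.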
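To see why `0x7F7F7F7F` is a unity scale, recall that an e8m0 value is a bare 8-bit exponent with bias 127, so `0x7F` decodes to `2^(127-127) = 1`. The sketch below decodes each byte of the 32-bit word under the assumption that the word holds four independent e8m0 scales; `e8m0_to_float`, `scale_byte`, and `is_unity_scale` are hypothetical helpers, not CK functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Decode one e8m0 value: 2^(bits - 127), with 0xFF reserved for NaN.
inline float e8m0_to_float(std::uint8_t bits)
{
    if(bits == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(bits) - 127);
}

// Extract byte i (0 = least significant) from a packed 32-bit scale word.
inline std::uint8_t scale_byte(std::uint32_t packed, int i)
{
    assert(i >= 0 && i < 4);
    return static_cast<std::uint8_t>((packed >> (8 * i)) & 0xFFu);
}

// True when every byte of the packed word decodes to exactly 1.0f.
inline bool is_unity_scale(std::uint32_t packed)
{
    for(int i = 0; i < 4; ++i)
    {
        if(e8m0_to_float(scale_byte(packed, i)) != 1.0f)
            return false;
    }
    return true;
}
```

Under these assumptions `is_unity_scale(0x7F7F7F7Fu)` holds, while an all-zero word does not: it decodes to `2^-127` per byte and would scale results toward zero.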
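The `DataVec` sizing fix in `load_tile_transpose.hpp` comes down to converting a logical element count into a physical one. The following is a minimal sketch of that arithmetic only; `packed_size`, `pk_fp4_demo`, `physical_vec_len`, and `DataVecDemo` are hypothetical names introduced for illustration and do not mirror the CK types.

```cpp
#include <cstdint>

// Hypothetical trait: number of logical values stored in one physical element.
template <typename T>
struct packed_size
{
    static constexpr int value = 1;
};

// Hypothetical packed type: two 4-bit values in one byte.
struct pk_fp4_demo
{
    std::uint8_t bits;
};

template <>
struct packed_size<pk_fp4_demo>
{
    static constexpr int value = 2;
};

// Number of physical elements needed to hold LogicalVecLen logical values.
template <typename T, int LogicalVecLen>
constexpr int physical_vec_len()
{
    static_assert(LogicalVecLen % packed_size<T>::value == 0,
                  "logical vector length must be a multiple of the packed size");
    return LogicalVecLen / packed_size<T>::value;
}

// A buffer sized in physical elements. Sizing it with LogicalVecLen directly
// would make it twice as large as the data actually loaded for pk_fp4_demo.
template <typename T, int LogicalVecLen>
struct DataVecDemo
{
    T data[physical_vec_len<T, LogicalVecLen>()];
};
```

With these definitions, 16 logical fp4 values occupy 8 physical elements, whereas an unpacked type such as `float` keeps a one-to-one mapping.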
CK Tile Testing Guide
This document describes the test organization and available test targets for CK Tile operations.
Overview
CK Tile tests are organized with multiple levels of granularity to support different development workflows:
- Global test labels - Run tests across all operations
- Operation-specific umbrella targets - Run all tests for a specific operation
- Individual test executables - Run specific tests
Global Test Labels
These targets run tests across all CK operations (not just CK Tile):
ninja smoke
Run fast smoke tests (tests that complete within ~30 seconds on gfx90a).
ninja smoke
ninja regression
Run slower, more comprehensive regression tests.
ninja regression
ninja check
Run ALL available tests in the entire codebase.
ninja check
Operation-Specific Umbrella Targets
These targets allow you to run all tests for a specific CK Tile operation. This is useful when making changes to a particular operation and wanting to validate all related tests without running the entire test suite.
GEMM Operations
ck_tile_gemm_tests
Run all basic GEMM pipeline tests (memory, compute variants, persistent, etc.)
ninja ck_tile_gemm_tests
Test executables included:
- test_ck_tile_gemm_pipeline_mem
- test_ck_tile_gemm_pipeline_compv3
- test_ck_tile_gemm_pipeline_compv4
- test_ck_tile_gemm_pipeline_persistent
- test_ck_tile_gemm_pipeline_compv6
- test_ck_tile_gemm_pipeline_comp_async (gfx95 only)
- test_ck_tile_gemm_pipeline_*_wmma variants (gfx11/gfx12 only)
ck_tile_gemm_block_scale_tests
Run all GEMM tests with block-scale quantization (AQuant, BQuant, ABQuant, etc.)
ninja ck_tile_gemm_block_scale_tests
Test executables included (29 in total), covering:
- AQuant tests (memory pipelines, base layouts, prefill, preshuffle, transpose)
- ABQuant tests (base, padding, preshuffle)
- BQuant tests (1D/2D variants, transpose)
- BQuant with PreshuffleB (decode/prefill, 1D/2D)
- BQuant with PreshuffleQuant (decode/prefill, 1D/2D)
- RowColQuant and TensorQuant tests
ck_tile_gemm_streamk_tests
Run all GEMM StreamK tests (tile partitioner, reduction, smoke, extended)
ninja ck_tile_gemm_streamk_tests
Test executables included:
- test_ck_tile_streamk_tile_partitioner
- test_ck_tile_streamk_reduction
- test_ck_tile_streamk_smoke
- test_ck_tile_streamk_extended
ck_tile_grouped_gemm_quant_tests
Run all grouped GEMM quantization tests
ninja ck_tile_grouped_gemm_quant_tests
Test executables included:
- test_ck_tile_grouped_gemm_quant_rowcol
- test_ck_tile_grouped_gemm_quant_tensor
- test_ck_tile_grouped_gemm_quant_aquant
- test_ck_tile_grouped_gemm_quant_bquant
- test_ck_tile_grouped_gemm_quant_bquant_preshuffleb
Other Operations
ck_tile_fmha_tests
Run all FMHA (Fused Multi-Head Attention) tests
ninja ck_tile_fmha_tests
Test executables included: Forward and backward tests for fp16, bf16, fp8bf16, fp32
ck_tile_reduce_tests
Run all reduce operation tests
ninja ck_tile_reduce_tests
Test executables included:
- test_ck_tile_reduce2d
- test_ck_tile_multi_reduce2d_threadwise
- test_ck_tile_multi_reduce2d_multiblock
Individual Test Executables
You can also build and run individual test executables:
Build a specific test
ninja test_ck_tile_gemm_pipeline_mem
Run a specific test directly
./build/bin/test_ck_tile_gemm_pipeline_mem
Run a specific test through ctest
ctest -R test_ck_tile_gemm_pipeline_mem --output-on-failure