mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[CK_TILE] Enable full transpose layout support for MX GEMM pipeline (#5813)

## Enable full transpose layout support for MX GEMM pipeline (32x32x64 MFMA)

### Summary

This PR enables all four matrix layout combinations (Row/Col, Row/Row, Col/Col, Col/Row) for the MX GEMM pipeline with `32x32x64` MFMA warp tiles, using `ds_read_tr` transposed LDS loads on gfx950. Previously, only the canonical `A=RowMajor, B=ColumnMajor` layout was supported.

### Changes

**Kernel-side transpose support:**

- **`warp_gemm_attribute_mfma.hpp`**: Introduce `kSplitFactor` logic in `get_warp_dstr_encoding` to split the K-dimension distribution encoding when `kPerLane` exceeds the `ds_read_tr` subtile minor dimension. This satisfies the `TransposeTileDistributionTraits` suffix validation required by `load_tile_transpose`. The distribution encoding now also receives the `DataType` template parameter to compute the split factor based on packed element size.
- **`gemm_pipeline_ag_bg_cr_comp_async.hpp`**: Uncomment and enable the `InputTileDistributionTraits` logic to properly transform LDS load tile distributions for transposed reads. Add `static_assert`s to catch misconfigurations where a layout requires transpose loads but the warp tile size disables them (e.g. `KWarpTile=128` exceeds `ds_read_tr` limits).
- **`load_tile_transpose.hpp`**: Fix `DataVec` sizing for packed types (`pk_fp4_t`): divide `vecLoadSize` by `PackedSize` to prevent buffer overflow when each physical element contains multiple logical values.
- **`warp_gemm_attribute_mfma_impl.hpp`**: Set `kDefaultScale` to `0x7F7F7F7F` (unity in e8m0 format) for the unscaled `operator()` overloads of `WarpGemmAttributeMfmaImpl_f32_32x32x64_f8f6f4`, ensuring correct behavior with `mfma_scale_f32_32x32x64_f8f6f4`.
- **`warp_gemm.hpp` / `warp_gemm_dispatcher.hpp`**: Add generic `WarpGemmMfma_f32_32x32x64_f8f6f4<A, B>` alias and dispatcher specialization to support arbitrary MX data type combinations (fp4, fp6, fp8) with the 32x32x64 MFMA, consolidating the existing type-specific aliases.
- **`gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp`**: Simplify `wg_attr_num_access` determination: `Double` for fp8, `Single` otherwise.

**Reference implementation fix:**

- **`reference_gemm.hpp`**: Fix nibble selection for packed 4-bit types (`pk_fp4_t`, `pk_int4_t`) in `reference_mx_gemm`, `reference_gemm`, and `reference_gemm_abquant`. The previous logic used `k % 2` or `index[K_DIM] & 1` to select which nibble to extract, which assumed K was always the fast (contiguous) memory dimension. This is only true for `A=RowMajor` / `B=ColumnMajor`. For other layouts, the fix computes the flat memory offset via `mDesc.GetOffsetFromMultiIndex(...)` and uses its parity to correctly select the nibble regardless of layout.

**Test infrastructure:**

- **`test_mx_gemm_config.hpp`**: Add `MxGemmConfig32` base and `MXfp4_GemmConfig32` / `MXfp8_GemmConfig32` configs for the 32x32x64 warp tile.
- **`test_mx_gemm_fp4.cpp` / `test_mx_gemm_fp8.cpp`**: Add `Config32` test suites covering all four layout combinations. Restrict `Config16` (16x16x128) to `A=Row, B=Col` only, since `KWarpTile=128` exceeds `ds_read_tr` limits.
- **`test_mx_gemm_util.hpp`**: Fix scale tensor layout: scales are always row-major `[M, K/32]` and column-major `[K/32, N]`, independent of A/B data layout.

### Test plan

- [x] `test_ck_tile_mx_gemm_fp4`: 5/5 passed (16x16x128 Row/Col + 32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp8`: 5/5 passed (16x16x128 Row/Col + 32x32x64 all 4 layouts)
- [x] `test_ck_tile_mx_gemm_fp6`: 1/1 passed (16x16x128 Row/Col)
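The nibble-selection fix in `reference_gemm.hpp` can be illustrated with a small standalone sketch. It is not the CK code: `Desc2D`, `load_nibble`, and `load_nibble_k_parity` are hypothetical names. The sketch assumes that offsets are counted in logical 4-bit elements and that the even-offset element sits in the low nibble of each byte; the actual packing order of `pk_fp4_t` / `pk_int4_t` may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 2D descriptor; strides are in logical (4-bit) elements.
struct Desc2D
{
    std::size_t stride_row; // stride along dimension 0 (M for matrix A)
    std::size_t stride_col; // stride along dimension 1 (K for matrix A)

    std::size_t offset(std::size_t r, std::size_t c) const
    {
        return r * stride_row + c * stride_col;
    }
};

// Layout-independent selection: the parity of the flat offset picks the nibble.
// Assumption: even offsets map to the low nibble, odd offsets to the high nibble.
inline std::uint8_t load_nibble(const std::vector<std::uint8_t>& packed,
                                const Desc2D& desc,
                                std::size_t r,
                                std::size_t c)
{
    const std::size_t flat = desc.offset(r, c);
    assert(flat / 2 < packed.size());
    const std::uint8_t byte = packed[flat / 2];
    return static_cast<std::uint8_t>((flat & 1) ? (byte >> 4) : (byte & 0x0F));
}

// Previous behaviour for comparison: the parity of the K index picks the nibble.
// This is only correct when K is the contiguous dimension.
inline std::uint8_t load_nibble_k_parity(const std::vector<std::uint8_t>& packed,
                                         const Desc2D& desc,
                                         std::size_t r,
                                         std::size_t k)
{
    const std::size_t flat = desc.offset(r, k);
    assert(flat / 2 < packed.size());
    const std::uint8_t byte = packed[flat / 2];
    return static_cast<std::uint8_t>((k & 1) ? (byte >> 4) : (byte & 0x0F));
}
```

For a 2x2 column-major A (M contiguous) packed as `{0x21, 0x43}`, element `(m=1, k=0)` has flat offset 1, so it lives in the high nibble of the first byte. The K-parity variant sees `k = 0` and reads the low nibble instead, which is the class of error the fix addresses. For row-major A with even K the two variants agree.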