MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)

* Add gemm_mx_fp8_bf8 example with row-major B

* Add more overloads of MX MFMA instructions

* Add MK_KN (RRR) tests

* Add KM_NK (CCR) tests

* Add more problem sizes to Large tests

* Add test_gemm_mx to the list of regression tests
This commit is contained in:
Andriy Roshchenko
2025-05-01 11:55:48 -06:00
committed by GitHub
parent b9d17bdb11
commit 79b0bfeb41
15 changed files with 642 additions and 18 deletions

View File

@@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto WaveSize = 64;
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl;
constexpr auto KThreadRead = WaveSize / MPerXdl;
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
@@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
}
else // RowMajor B
{
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto WaveSize = 64;
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl;
constexpr auto KThreadRead = WaveSize / NPerXdl;
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)