mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)
* Add gemm_mx_fp8_bf8 example with row-major B * Add more overloads of MX MFMA instructions * Add MK_KN (RRR) tests * Add KM_NK (CCR) tests * Add more problem sizes to Large tests * Add test_gemm_mx to the list of regression tests
This commit is contained in:
committed by
GitHub
parent
b9d17bdb11
commit
79b0bfeb41
@@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
|
||||
Reference in New Issue
Block a user