mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)
* Add gemm_mx_fp8_bf8 example with row-major B
* Add more overloads of MX MFMA instructions
* Add MK_KN (RRR) tests
* Add KM_NK (CCR) tests
* Add more problem sizes to Large tests
* Add test_gemm_mx to the list of regression tests
This commit is contained in:
committed by
GitHub
parent
b9d17bdb11
commit
79b0bfeb41
@@ -797,12 +797,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg offset
|
||||
// for compiler.
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
|
||||
constexpr auto M1 = MPerBlock / M0;
|
||||
|
||||
constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
|
||||
@@ -929,12 +930,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
}
|
||||
else // RowMajor B
|
||||
{
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
constexpr auto WaveSize = 64;
|
||||
constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
|
||||
constexpr auto N1 = NPerBlock / N0;
|
||||
|
||||
constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
|
||||
constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
|
||||
constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl;
|
||||
constexpr auto KThreadRead = WaveSize / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
|
||||
|
||||
constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
|
||||
|
||||
@@ -1129,6 +1129,12 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
@@ -1147,6 +1153,18 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user