Remove CK_USE_AMD_MFMA_GFX950 (#1935)

* Add runtime check in example_gemm_xdl_streamk for gfx950

* Add runtime check in grouped conv fwd examples for gfx950

* Disable CK_USE_AMD_MFMA_GFX950

* Add new instances for gfx950

* Fix test_gemm_universal on gfx950
This commit is contained in:
jefyang1
2025-03-04 10:32:25 -08:00
committed by GitHub
parent 540a6da40b
commit c95bda93ba
186 changed files with 3272 additions and 883 deletions

View File

@@ -517,8 +517,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =

View File

@@ -450,8 +450,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -361,9 +361,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
constexpr bool is_single_rate_mfma =
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
math::lcm(A0K1, B0K1) <= 4)
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4) ||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
? true
: false;
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
@@ -653,8 +655,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// sanity check
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
constexpr bool is_single_rate_mfma =
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4)
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4) ||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
? true
: false;
constexpr index_t KPack =

View File

@@ -343,9 +343,11 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
math::lcm(AK1, BK1) <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr auto mfma =
@@ -560,8 +562,9 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -471,8 +471,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -500,8 +500,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -674,10 +674,22 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -466,8 +466,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -635,10 +635,22 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -600,10 +600,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -601,8 +601,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -453,8 +453,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -583,8 +583,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =
@@ -1015,8 +1016,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack =

View File

@@ -597,8 +597,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -81,8 +81,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =

View File

@@ -141,8 +141,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =

View File

@@ -810,9 +810,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -871,8 +871,9 @@ struct GridwiseGemm_xdl_cshuffle_v2
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -149,8 +149,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =

View File

@@ -157,8 +157,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =

View File

@@ -193,9 +193,17 @@ struct GridwiseGemm_xdl_cshuffle_v3
using BsGridPointer = decltype(MakeBsGridPointer());
using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -179,9 +179,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -149,9 +149,18 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
using DsGridPointer = decltype(MakeDsGridPointer());
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -491,8 +491,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -489,8 +489,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
// branch early for math wave
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t KPack = math::max(

View File

@@ -741,11 +741,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr bool is_single_rate_mfma =
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
K1 <= 4) ||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8))
? true
: false;
constexpr index_t KPack =
math::max(K1,
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
.k_per_blk);
constexpr index_t KPack = math::max(
K1,
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted, is_single_rate_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,

View File

@@ -448,8 +448,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
// sanity check
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
constexpr bool is_single_rate_mfma =
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4)
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
? true
: false;
constexpr index_t k_pack = math::max(