Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-03 21:21:22 +00:00)
Remove CK_USE_AMD_MFMA_GFX950 (#1935)
* Add runtime check in example_gemm_xdl_streamk for gfx950
* Add runtime check in grouped conv fwd examples for gfx950
* Disable CK_USE_AMD_MFMA_GFX950
* Add new instances for gfx950
* Fix test_gemm_universal on gfx950
This commit is contained in:
@@ -517,8 +517,9 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
|
||||
@@ -450,8 +450,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -361,9 +361,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
const auto M = d0_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = d0_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
math::lcm(A0K1, B0K1) <= 4)
|
||||
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4) ||
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
|
||||
@@ -653,8 +655,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4)
|
||||
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4) ||
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
|
||||
@@ -343,9 +343,11 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
const auto M = d0_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = d0_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
math::lcm(AK1, BK1) <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto mfma =
|
||||
@@ -560,8 +562,9 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -471,8 +471,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -500,8 +500,9 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -674,10 +674,22 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
|
||||
.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -466,8 +466,9 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -635,10 +635,22 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
|
||||
.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -600,10 +600,22 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
// b_mtx[K0PerBlock, NPerBlock] is in LDS
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
|
||||
.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -601,8 +601,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -453,8 +453,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -583,8 +583,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
@@ -1015,8 +1016,9 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
|
||||
@@ -597,8 +597,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -81,8 +81,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
|
||||
@@ -141,8 +141,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
|
||||
@@ -810,9 +810,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -871,8 +871,9 @@ struct GridwiseGemm_xdl_cshuffle_v2
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -149,8 +149,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
|
||||
@@ -157,8 +157,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
|
||||
@@ -193,9 +193,17 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
using BsGridPointer = decltype(MakeBsGridPointer());
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -179,9 +179,18 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -149,9 +149,18 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
|
||||
using DsGridPointer = decltype(MakeDsGridPointer());
|
||||
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk);
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -491,8 +491,9 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -489,8 +489,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
|
||||
// branch early for math wave
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
|
||||
@@ -741,11 +741,17 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
|
||||
K1 <= 4) ||
|
||||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted>::selected_mfma
|
||||
.k_per_blk);
|
||||
constexpr index_t KPack = math::max(
|
||||
K1,
|
||||
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
|
||||
@@ -448,8 +448,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t k_pack = math::max(
|
||||
|
||||
Reference in New Issue
Block a user