mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Fix KPack and enable existing instances on gfx950 (#1871)
This commit is contained in:
@@ -515,9 +515,16 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -448,8 +448,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
// acc1[m][o] += acc[m][n] * B1[n][o]
|
||||
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -361,10 +361,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
const auto M = d0_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = d0_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto mfma =
|
||||
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
math::lcm(A0K1, B0K1) <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
A0B0B1DataType,
|
||||
is_single_rate_mfma>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
return transform_tensor_descriptor(
|
||||
d0_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
@@ -643,9 +651,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
// acc1[m][o] += acc[m][n] * B1[n][o]
|
||||
|
||||
// sanity check
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(A0K1, B0K1),
|
||||
MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_A0K1_B0K1,
|
||||
MfmaSelector<A0B0B1DataType,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
A0B0B1DataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -343,10 +343,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
const auto M = d0_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = d0_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N4 = mfma.num_input_blks;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
math::lcm(AK1, BK1) <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr auto mfma =
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N4 = mfma.num_input_blks;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
return transform_tensor_descriptor(
|
||||
d0_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(
|
||||
@@ -552,8 +558,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
// acc1[m][o] += acc[m][n] * B1[n][o]
|
||||
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -469,8 +469,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
// acc1[m][o] += acc[m][n] * B1[n][o]
|
||||
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -498,8 +498,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -464,8 +464,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -599,9 +599,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<AComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -451,8 +451,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -581,9 +581,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
@@ -1006,9 +1013,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -595,9 +595,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -79,9 +79,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Executable file → Normal file
11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
Executable file → Normal file
@@ -139,9 +139,16 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
|
||||
@@ -869,9 +869,16 @@ struct GridwiseGemm_xdl_cshuffle_v2
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
// BlockSize,
|
||||
|
||||
@@ -147,9 +147,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -155,9 +155,16 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number),
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -1424,7 +1431,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
|
||||
// b scale
|
||||
// static_assert(KPerBlock <= ScaleBlockK);
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>{};
|
||||
static constexpr auto mfma =
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
|
||||
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
@@ -1895,7 +1903,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
KPerBlock);
|
||||
|
||||
// B scale
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>{};
|
||||
static constexpr auto mfma =
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>{};
|
||||
static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
|
||||
static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
|
||||
static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
|
||||
|
||||
@@ -489,8 +489,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -487,9 +487,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
|
||||
else if(TileMathThreadGroup::IsBelong())
|
||||
{
|
||||
// branch early for math wave
|
||||
constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1, BK1),
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
TileMathThreadGroupSize,
|
||||
|
||||
@@ -446,8 +446,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
|
||||
// register
|
||||
// sanity check
|
||||
constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
|
||||
constexpr bool is_single_rate_mfma =
|
||||
((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4)
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t k_pack = math::max(
|
||||
math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
|
||||
Reference in New Issue
Block a user