mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Use new mfma instructions for FP8 on gfx950 (#2202)
* Add logic to use new mfma instructions for fp8 bf8
* Fix example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 on gfx950 and run clang format
* Update include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
* Fix intrin_mfma f8 calls due to merge mistake
---------
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
[ROCm/composable_kernel commit: f18170064d]
This commit is contained in:
@@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// this instance has been tested working on gfx950
|
||||
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
|
||||
@@ -124,7 +124,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
@@ -145,6 +144,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
|
||||
|
||||
using Base::MWaves;
|
||||
|
||||
static constexpr auto xdlops_gemm =
|
||||
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
@@ -519,13 +519,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ABDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -452,13 +452,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -365,16 +365,20 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4) ||
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8) ||
|
||||
((is_same<A0B0B1DataType, f8_t>::value || is_same<A0B0B1DataType, bf8_t>::value) &&
|
||||
lcm_A0K1_B0K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
A0B0B1DataType,
|
||||
is_single_rate_mfma>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
return transform_tensor_descriptor(
|
||||
d0_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
@@ -657,16 +661,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
|
||||
lcm_A0K1_B0K1 <= 4) ||
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
|
||||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8) ||
|
||||
((is_same<A0B0B1DataType, f8_t>::value || is_same<A0B0B1DataType, bf8_t>::value) &&
|
||||
lcm_A0K1_B0K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_A0K1_B0K1,
|
||||
MfmaSelector<A0B0B1DataType,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
A0B0B1DataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_A0K1_B0K1,
|
||||
MfmaSelector<A0B0B1DataType,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
A0B0B1DataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -347,11 +347,15 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr auto mfma =
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma;
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma;
|
||||
constexpr auto N3 = mfma.num_groups_per_blk;
|
||||
constexpr auto N4 = mfma.num_input_blks;
|
||||
constexpr auto N5 = mfma.group_size;
|
||||
@@ -564,13 +568,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -473,13 +473,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
|
||||
BlockSize,
|
||||
|
||||
@@ -502,13 +502,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -679,17 +679,19 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -468,13 +468,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -647,17 +647,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -605,17 +605,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
(((is_same<AComputeDataType, half_t>::value ||
|
||||
is_same<AComputeDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr auto is_scale_mfma = false;
|
||||
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma>::selected_mfma.k_per_blk);
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
BComputeDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -603,13 +603,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<AComputeType, f8_t>::value || is_same<AComputeType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<AComputeType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
AComputeType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -455,13 +455,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -585,13 +585,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ABDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
@@ -1018,13 +1024,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ABDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -599,13 +599,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeType, f8_t>::value || is_same<ComputeType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -83,13 +83,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -144,13 +144,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
|
||||
@@ -814,13 +814,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -873,13 +873,19 @@ struct GridwiseGemm_xdl_cshuffle_v2
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
// BlockSize,
|
||||
|
||||
@@ -255,13 +255,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -148,13 +148,21 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
|
||||
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma = true;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
|
||||
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
|
||||
@@ -160,13 +160,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -198,13 +198,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -183,14 +183,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -153,14 +153,20 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeB,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
|
||||
@@ -164,12 +164,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
|
||||
static constexpr index_t KPackPerGroup = KPack / KGroup;
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
|
||||
@@ -493,13 +493,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
|
||||
BlockSize,
|
||||
|
||||
@@ -491,13 +491,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t KPack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack =
|
||||
math::max(lcm_AK1_BK1,
|
||||
MfmaSelector<ABDataType,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ABDataType,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
|
||||
TileMathThreadGroupSize,
|
||||
|
||||
@@ -744,14 +744,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
|
||||
K1 <= 4) ||
|
||||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8))
|
||||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8) ||
|
||||
((is_same<FloatAAdjusted, f8_t>::value || is_same<FloatAAdjusted, bf8_t>::value) &&
|
||||
K1 < 32))
|
||||
? true
|
||||
: false;
|
||||
|
||||
constexpr index_t KPack = math::max(
|
||||
K1,
|
||||
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted, is_single_rate_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t KPack = math::max(K1,
|
||||
MfmaSelector<FloatAAdjusted,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
FloatBAdjusted,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>::selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
|
||||
@@ -450,13 +450,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
|
||||
constexpr bool is_single_rate_mfma =
|
||||
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
|
||||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
constexpr index_t k_pack = math::max(
|
||||
constexpr auto is_scale_mfma = false;
|
||||
constexpr index_t k_pack = math::max(
|
||||
lcm_AK1_BK1,
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
|
||||
.k_per_blk);
|
||||
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
|
||||
selected_mfma.k_per_blk);
|
||||
|
||||
auto blockwise_gemm =
|
||||
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
|
||||
|
||||
@@ -183,14 +183,27 @@ struct GridwiseMoeGemm
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
|
||||
static constexpr index_t KPack =
|
||||
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
|
||||
static constexpr index_t KLane =
|
||||
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
|
||||
static constexpr index_t NLane = NPerXdl;
|
||||
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
|
||||
// static constexpr index_t NumTokens = 1;
|
||||
static constexpr index_t SortedTileSize = MPerBlock;
|
||||
|
||||
|
||||
@@ -1117,12 +1117,31 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
|
||||
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32, pk_i4_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
|
||||
{
|
||||
@@ -1136,11 +1155,21 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
|
||||
{
|
||||
@@ -1166,41 +1195,101 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x16f8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
#endif
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
|
||||
MPerXdlops,
|
||||
NPerXdlops,
|
||||
@@ -1530,15 +1619,23 @@ struct XdlopsGemm
|
||||
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
|
||||
}
|
||||
|
||||
// Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
|
||||
static constexpr auto mfma = MfmaSelector < base_type, MPerXdlops, NPerXdlops, additional_type,
|
||||
(((is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, bhalf_t>::value) &&
|
||||
KPack <= 4) ||
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8))
|
||||
? true
|
||||
: false,
|
||||
is_scale_mfma > {};
|
||||
// Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942-
|
||||
// when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t,
|
||||
// except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
|
||||
static constexpr bool is_single_rate_mfma =
|
||||
(((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
|
||||
KPack <= 4) ||
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8) ||
|
||||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
|
||||
is_same<additional_type, pk_i4_t>::value)
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto mfma = MfmaSelector<base_type,
|
||||
MPerXdlops,
|
||||
NPerXdlops,
|
||||
additional_type,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
|
||||
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
|
||||
@@ -533,6 +533,50 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
@@ -1118,6 +1162,52 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
0, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
|
||||
1, // blgp
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = reg_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user