Use new mfma instructions for FP8 on gfx950 (#2202)

* Add logic to use new mfma instructions for fp8 bf8

* Fix example_gemm_xdl_fp8_pk_i4_bpreshuffle_v3 on gfx950 and run clang format

* Update include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

* Fix intrin_mfma f8 calls due to merge mistake

---------

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>

[ROCm/composable_kernel commit: f18170064d]
This commit is contained in:
jefyang1
2025-05-19 17:29:51 -07:00
committed by GitHub
parent 8c6eb1c0b8
commit 23a8bed9af
34 changed files with 548 additions and 180 deletions

View File

@@ -32,6 +32,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// This instance has been tested and confirmed working on gfx950.
// < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 128, 32, 32, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::

View File

@@ -124,7 +124,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::I1;
using Base::I2;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
using Base::a_block_desc_m0_m1_m2_k;
@@ -145,6 +144,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::MWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;

View File

@@ -519,13 +519,19 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType,
MPerXdl,
NPerXdl,
ABDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -452,13 +452,16 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,

View File

@@ -365,16 +365,20 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4) ||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8) ||
((is_same<A0B0B1DataType, f8_t>::value || is_same<A0B0B1DataType, bf8_t>::value) &&
lcm_A0K1_B0K1 < 32))
? true
: false;
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
constexpr auto is_scale_mfma = false;
constexpr auto mfma = MfmaSelector<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
A0B0B1DataType,
is_single_rate_mfma>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
is_single_rate_mfma,
is_scale_mfma>::selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N5 = mfma.group_size;
return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(
@@ -657,16 +661,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<A0B0B1DataType, half_t>::value || is_same<A0B0B1DataType, bhalf_t>::value) &&
lcm_A0K1_B0K1 <= 4) ||
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8))
(is_same<A0B0B1DataType, int8_t>::value && lcm_A0K1_B0K1 <= 8) ||
((is_same<A0B0B1DataType, f8_t>::value || is_same<A0B0B1DataType, bf8_t>::value) &&
lcm_A0K1_B0K1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_A0K1_B0K1,
MfmaSelector<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
A0B0B1DataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_A0K1_B0K1,
MfmaSelector<A0B0B1DataType,
Gemm0MPerXdl,
Gemm0NPerXdl,
A0B0B1DataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm0 = BlockwiseGemmXdlops_v2<
BlockSize,

View File

@@ -347,11 +347,15 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr auto is_scale_mfma = false;
constexpr auto mfma =
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma;
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma;
constexpr auto N3 = mfma.num_groups_per_blk;
constexpr auto N4 = mfma.num_input_blks;
constexpr auto N5 = mfma.group_size;
@@ -564,13 +568,16 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,

View File

@@ -473,13 +473,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_v2<
BlockSize,

View File

@@ -502,13 +502,16 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -679,17 +679,19 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
static constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -468,13 +468,16 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -647,17 +647,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -605,17 +605,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
(((is_same<AComputeDataType, half_t>::value ||
is_same<AComputeDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<AComputeDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeDataType, f8_t>::value || is_same<AComputeDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr auto is_scale_mfma = false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma>::selected_mfma.k_per_blk);
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeDataType,
MPerXdl,
NPerXdl,
BComputeDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -603,13 +603,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<AComputeType, half_t>::value || is_same<AComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<AComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<AComputeType, f8_t>::value || is_same<AComputeType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<AComputeType, MPerXdl, NPerXdl, AComputeType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<AComputeType,
MPerXdl,
NPerXdl,
AComputeType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -455,13 +455,16 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -585,13 +585,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType,
MPerXdl,
NPerXdl,
ABDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
@@ -1018,13 +1024,19 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType,
MPerXdl,
NPerXdl,
ABDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -599,13 +599,19 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<ComputeType, half_t>::value || is_same<ComputeType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeType, f8_t>::value || is_same<ComputeType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ComputeType, MPerXdl, NPerXdl, ComputeType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ComputeType,
MPerXdl,
NPerXdl,
ComputeType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -83,13 +83,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -144,13 +144,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
__host__ static auto CalculateMPadded(index_t M)

View File

@@ -814,13 +814,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -873,13 +873,19 @@ struct GridwiseGemm_xdl_cshuffle_v2
constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,

View File

@@ -255,13 +255,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -148,13 +148,21 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
// Use the single-rate mfma instruction for this special case: A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma = true;
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;

View File

@@ -160,13 +160,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeA, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -198,13 +198,20 @@ struct GridwiseGemm_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -183,14 +183,20 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -153,14 +153,20 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB, is_single_rate_mfma>::
selected_mfma.k_per_blk);
MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeB,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
using ThisThreadBlock = ThisThreadBlock<BlockSize>;

View File

@@ -164,12 +164,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
static constexpr index_t KPackPerGroup = KPack / KGroup;
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
static constexpr index_t NLane = NPerXdl;

View File

@@ -493,13 +493,16 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,

View File

@@ -491,13 +491,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle
constexpr bool is_single_rate_mfma =
(((is_same<ABDataType, half_t>::value || is_same<ABDataType, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<ABDataType, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ABDataType, f8_t>::value || is_same<ABDataType, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
lcm_AK1_BK1,
MfmaSelector<ABDataType, MPerXdl, NPerXdl, ABDataType, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack =
math::max(lcm_AK1_BK1,
MfmaSelector<ABDataType,
MPerXdl,
NPerXdl,
ABDataType,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
TileMathThreadGroupSize,

View File

@@ -744,14 +744,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr bool is_single_rate_mfma =
(((is_same<FloatAAdjusted, half_t>::value || is_same<FloatAAdjusted, bhalf_t>::value) &&
K1 <= 4) ||
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8))
(is_same<FloatAAdjusted, int8_t>::value && K1 <= 8) ||
((is_same<FloatAAdjusted, f8_t>::value || is_same<FloatAAdjusted, bf8_t>::value) &&
K1 < 32))
? true
: false;
constexpr index_t KPack = math::max(
K1,
MfmaSelector<FloatAAdjusted, MPerXDL, NPerXDL, FloatBAdjusted, is_single_rate_mfma>::
selected_mfma.k_per_blk);
constexpr auto is_scale_mfma = false;
constexpr index_t KPack = math::max(K1,
MfmaSelector<FloatAAdjusted,
MPerXDL,
NPerXDL,
FloatBAdjusted,
is_single_rate_mfma,
is_scale_mfma>::selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,

View File

@@ -450,13 +450,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
constexpr bool is_single_rate_mfma =
(((is_same<FloatAB, half_t>::value || is_same<FloatAB, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8))
(is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<FloatAB, f8_t>::value || is_same<FloatAB, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
constexpr index_t k_pack = math::max(
constexpr auto is_scale_mfma = false;
constexpr index_t k_pack = math::max(
lcm_AK1_BK1,
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma>::selected_mfma
.k_per_blk);
MfmaSelector<FloatAB, MPerXdl, NPerXdl, FloatAB, is_single_rate_mfma, is_scale_mfma>::
selected_mfma.k_per_blk);
auto blockwise_gemm =
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,

View File

@@ -183,14 +183,27 @@ struct GridwiseMoeGemm
static constexpr index_t NumDTensor = DsDataType::Size();
using mfma_selector = MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>;
static constexpr index_t KPack =
math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk);
static constexpr index_t KLane =
mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops();
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
// Least common multiple of the A-side and B-side K1 values; used below both to decide
// the MFMA rate and as the lower bound for KPack.
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
// True when lcm_AK1_BK1 is below the per-type threshold:
//   half_t / bhalf_t : lcm_AK1_BK1 <= 4
//   int8_t           : lcm_AK1_BK1 <= 8
//   f8_t / bf8_t     : lcm_AK1_BK1 <  32
// Only ComputeTypeA is inspected; any other ComputeTypeA yields false.
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
? true
: false;
// Scaled MFMA variants are never requested here.
static constexpr auto is_scale_mfma = false;
// Instruction selector for this kernel.
// NOTE(review): ComputeTypeA is passed for both type arguments of MfmaSelector, so
// ComputeTypeB is not consulted here — confirm this is intended when the two compute
// types differ (e.g. f8_t with bf8_t or pk_i4_t).
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
// KPack is at least the selected instruction's k_per_blk.
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
// Ratio of the selector's GetKPerXdlops() to GetK1PerXdlops().
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
// Integer division: KPerBlock split first by KLane, then by KPack.
static constexpr index_t KRepeat = KPerBlock / KLane / KPack;
static constexpr index_t NLane = NPerXdl;
// Integer division: NPerBlock split first by NPerXdl, then by NXdlPerWave.
static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave;
// static constexpr index_t NumTokens = 1;
static constexpr index_t SortedTileSize = MPerBlock;

View File

@@ -1117,12 +1117,31 @@ struct MfmaSelector
#endif
}
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
template <>
constexpr auto GetMfma<f8_t, 32, 32>()
constexpr auto GetMfma<f8_t, 32, 32, pk_i4_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
// f8_t x f8_t, 32x32, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_32x32x16f8f8; there is no target-architecture check.
template <>
constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
// f8_t x f8_t, 32x32, flags <false, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// When compiling for gfx950 (__gfx950__ defined) selects mfma_f32_32x32x64f8f6f4;
// on every other target selects mfma_f32_32x32x16f8f8.
template <>
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
#else
return MfmaInstr::mfma_f32_32x32x16f8f8;
#endif
}
template <>
constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
{
@@ -1136,11 +1155,21 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<f8_t, 16, 16>()
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
// f8_t x f8_t, 16x16, flags <false, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// When compiling for gfx950 (__gfx950__ defined) selects mfma_f32_16x16x128f8f6f4;
// on every other target selects mfma_f32_16x16x32f8f8.
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32f8f8;
#endif
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
{
@@ -1166,41 +1195,101 @@ struct MfmaSelector
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32>()
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
constexpr auto GetMfma<bf8_t, 16, 16>()
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
#else
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
#endif
}
// bf8_t x bf8_t, 16x16, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_16x16x32bf8bf8; there is no target-architecture check.
template <>
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
template <>
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
#endif
}
// Mixed f8_t x bf8_t, 32x32, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_32x32x16f8bf8; there is no target-architecture check.
template <>
constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
#else
return MfmaInstr::mfma_f32_32x32x16f8bf8;
#endif
}
// Mixed f8_t x bf8_t, 16x16, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_16x16x32f8bf8; there is no target-architecture check.
template <>
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32f8bf8;
#endif
}
// Mixed bf8_t x f8_t, 32x32, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_32x32x16bf8f8; there is no target-architecture check.
template <>
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x64f8f6f4;
#else
return MfmaInstr::mfma_f32_32x32x16bf8f8;
#endif
}
// Mixed bf8_t x f8_t, 16x16, flags <true, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// Always selects mfma_f32_16x16x32bf8f8; there is no target-architecture check.
template <>
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}
// Mixed bf8_t x f8_t, 16x16, flags <false, false> (presumably is_single_rate_mfma and
// is_scale_mfma, in that order — confirm against the primary GetMfma template).
// When compiling for gfx950 (__gfx950__ defined) selects mfma_f32_16x16x128f8f6f4;
// on every other target selects mfma_f32_16x16x32bf8f8.
template <>
constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x128f8f6f4;
#else
return MfmaInstr::mfma_f32_16x16x32bf8f8;
#endif
}
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
MPerXdlops,
NPerXdlops,
@@ -1530,15 +1619,23 @@ struct XdlopsGemm
return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
}
// Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
static constexpr auto mfma = MfmaSelector < base_type, MPerXdlops, NPerXdlops, additional_type,
(((is_same<base_type, half_t>::value ||
is_same<base_type, bhalf_t>::value) &&
KPack <= 4) ||
(is_same<base_type, int8_t>::value && KPack <= 8))
? true
: false,
is_scale_mfma > {};
// Choose between the single-rate and the non-single-rate MFMA instruction.
// is_single_rate_mfma is true when KPack is below the per-type threshold:
//   half_t / bhalf_t : KPack <= 4
//   int8_t           : KPack <= 8
//   f8_t / bf8_t     : KPack <  32
// or when additional_type is pk_i4_t. The latter covers the special case
// A (f8_t) * B (pk_i4_t), which always uses a single-rate instruction; otherwise, when
// base_type is f8_t or bf8_t, additional_type is expected to be f8_t or bf8_t as well.
// The flag makes gfx950 fall back to the single-rate instruction; it changes nothing on
// gfx942 (the original note reads "gfx942-", presumably meaning gfx942 and earlier).
static constexpr bool is_single_rate_mfma =
(((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
KPack <= 4) ||
(is_same<base_type, int8_t>::value && KPack <= 8) ||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
is_same<additional_type, pk_i4_t>::value)
? true
: false;
// Selector instantiated with the rate flag computed above and the is_scale_mfma value
// already in scope.
static constexpr auto mfma = MfmaSelector<base_type,
MPerXdlops,
NPerXdlops,
additional_type,
is_single_rate_mfma,
is_scale_mfma>{};
// Shorthand for the selector's selected_mfma member.
static constexpr auto mfma_instr = mfma.selected_mfma;

View File

@@ -533,6 +533,50 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
#endif
}
// Overload for A = bf8x32_t, B = f8x32_t.
// On gfx950: calls __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 with reg_a, reg_b and
// element 0 of reg_c viewed as float16_t, then stores the result back into that same
// element. cbsz = 1 and blgp = 0 are passed as the format selectors.
// On any other target no MFMA is issued; the arguments are only assigned to `ignore`
// (presumably to silence unused-parameter warnings).
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp (presumably the same encoding as cbsz, applied to reg_b — confirm)
// The remaining four operands are all zero; presumably the scale-related
// arguments of the builtin — confirm against its signature.
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
// Overload for A = f8x32_t, B = bf8x32_t.
// On gfx950: calls __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4 with reg_a, reg_b and
// element 0 of reg_c viewed as float16_t, then stores the result back into that same
// element. cbsz = 0 and blgp = 1 are passed as the format selectors.
// On any other target no MFMA is issued; the arguments are only assigned to `ignore`
// (presumably to silence unused-parameter warnings).
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp (presumably the same encoding as cbsz, applied to reg_b — confirm)
// The remaining four operands are all zero; presumably the scale-related
// arguments of the builtin — confirm against its signature.
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
@@ -1118,6 +1162,52 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
#endif
}
// Overload for A = bf8x32_t, B = f8x32_t.
// On gfx950: calls __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 with reg_a, reg_b and
// element 0 of reg_c viewed as float4_t, then stores the result back into that same
// element. cbsz = 1 and blgp = 0 are passed as the format selectors.
// On any other target no MFMA is issued; the arguments are only assigned to `ignore`
// (presumably to silence unused-parameter warnings).
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp (presumably the same encoding as cbsz, applied to reg_b — confirm)
// The remaining four operands are all zero; presumably the scale-related
// arguments of the builtin — confirm against its signature.
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
// Overload for A = f8x32_t, B = bf8x32_t.
// On gfx950: calls __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4 with reg_a, reg_b and
// element 0 of reg_c viewed as float4_t, then stores the result back into that same
// element. cbsz = 0 and blgp = 1 are passed as the format selectors.
// On any other target no MFMA is issued; the arguments are only assigned to `ignore`
// (presumably to silence unused-parameter warnings).
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp (presumably the same encoding as cbsz, applied to reg_b — confirm)
// The remaining four operands are all zero; presumably the scale-related
// arguments of the builtin — confirm against its signature.
0,
0,
0,
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{