mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Optimized GEMMs for MX FP4/8 (#2294)
Adds V3 GEMM pipeline for MX FP4 and MX FP8. Adds V3 GEMM pipeline for MX FP4 with preshuffling. Adds MXFP4 GEMM tests (#2275). Adds MXFP4 GEMM examples. Adds MXFP4 GEMMs to ckProfiler. Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
233e274077
commit
00247e3c29
@@ -173,18 +173,34 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
|
||||
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
{
|
||||
// A matrix in LDS memory, destination of blockwise copy.
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
|
||||
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
|
||||
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
// FIXME: our support for non-K-contiguous layouts is limited; it only works in some specific
|
||||
// settings
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
|
||||
make_tuple(AK1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
|
||||
{
|
||||
// B matrix in LDS memory, destination of blockwise copy.
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
|
||||
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
// FIXME: our support for non-K-contiguous layouts is limited; it only works in some specific
|
||||
// settings
|
||||
return make_naive_tensor_descriptor_packed(
|
||||
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
|
||||
make_tuple(BK1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
@@ -566,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<AK0PerBlock, MPerBlock, AK1>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ADataType,
|
||||
AComputeDataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
2,
|
||||
ABlockTransferScalarPerVector>(
|
||||
@@ -582,10 +600,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<BK0PerBlock, NPerBlock, BK1>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BDataType,
|
||||
BComputeDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
2,
|
||||
BBlockTransferScalarPerVector>(
|
||||
|
||||
@@ -256,8 +256,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
// gfx950 double rate mfma16x16 requires at least 128 KPerBlock to consume
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
KPerBlock < 128 && MPerXdl == 16))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -184,8 +184,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
|
||||
lcm_AK1_BK1 <= 4) ||
|
||||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
|
||||
// gfx950 double rate mfma16x16 requires at least 128 KPerBlock to consume
|
||||
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
|
||||
lcm_AK1_BK1 < 32))
|
||||
KPerBlock < 128 && MPerXdl == 16))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
|
||||
@@ -173,15 +173,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
|
||||
lcm_AK1_BK1 < 32))
|
||||
? true
|
||||
: false;
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
static constexpr auto is_scale_mfma = false;
|
||||
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
ComputeTypeA,
|
||||
is_single_rate_mfma,
|
||||
is_scale_mfma>{};
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
|
||||
static constexpr index_t KGroup = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
|
||||
// On gfx950, we have an mfma that requires 32 f8 elements as input,
|
||||
// split into 2 groups of 16 f8 elements.
|
||||
// The 2 groups are not contiguous in the B preshuffled layout,
|
||||
// and we do not want them to be contiguous in the B preshuffled layout
|
||||
// because a memory instruction can only read 16 f8 elements at a time.
|
||||
return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
|
||||
static constexpr index_t KPackPerGroup = KPack / KGroup;
|
||||
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -76,10 +76,12 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
bool BBlockLdsExtraN,
|
||||
@@ -102,9 +104,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
|
||||
// K1 should be Number<...>
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
static constexpr auto K1 = Number<K1Value>{};
|
||||
static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
|
||||
static constexpr auto M01 = 1;
|
||||
static constexpr auto N01 = 1;
|
||||
|
||||
static constexpr auto gemm_padder =
|
||||
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
|
||||
@@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
make_tuple(K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
make_tuple(
|
||||
Number<KPerBlock>{} * Number<MPerBlock>{}, K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
@@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
make_tuple(K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_aligned(
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
|
||||
max_lds_align);
|
||||
make_tuple(
|
||||
Number<KPerBlock>{} * Number<NPerBlock>{}, K1, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<1, K0PerBlock, MPerBlock, K1>,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
FloatA,
|
||||
ComputeType,
|
||||
decltype(a_b_k0_m_k1_grid_desc),
|
||||
decltype(a_b_k0_m_k1_block_desc),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
3,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
@@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
|
||||
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
|
||||
Sequence<1, K0PerBlock, NPerBlock, K1>,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
FloatB,
|
||||
ComputeType,
|
||||
decltype(b_b_k0_n_k1_grid_desc),
|
||||
decltype(b_b_k0_n_k1_block_desc),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
3,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
|
||||
Reference in New Issue
Block a user