mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit '054f85ab7c0fa07a90968e834899ec415af8b713' into develop
This commit is contained in:
@@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
|
||||
static constexpr index_t AMmaKStride = KPack;
|
||||
static constexpr index_t BMmaKStride = KPack;
|
||||
|
||||
//> store rows/cols into thread registers in chunks of 16
|
||||
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
|
||||
// store rows/cols into thread registers in chunks of 16 for FP8
|
||||
// e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
|
||||
// or in chunks of 32 / APackedSize for FP6/FP4
|
||||
static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
|
||||
|
||||
static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now");
|
||||
|
||||
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
|
||||
static constexpr index_t KRepeat = KPerThread / KPack;
|
||||
|
||||
@@ -54,6 +54,8 @@ namespace device {
|
||||
*
|
||||
* Conditions for achieving computational load balancing on different hardware platforms can vary.
|
||||
*
|
||||
* \tparam KPerBlock Number of elements in the K dimension that each block processes (multiply by packed_size_v to get the actual KPerBlock)
|
||||
*
|
||||
* Serialized version of the algorithm:
|
||||
* \code
|
||||
* // E = A * B + C
|
||||
@@ -117,7 +119,7 @@ template <typename ALayout,
|
||||
index_t BlockSize, // Thread block size
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t KPerBlock, // multiply with packed_size_v to get the actual KPerBlock
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
|
||||
@@ -419,6 +419,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
|
||||
(GemmSpec != GemmSpecialization::Default &&
|
||||
GemmSpec != GemmSpecialization::MPadding)),
|
||||
"f4x2_pk_t does not support K padding");
|
||||
static_assert(!((is_same_v<remove_cvref_t<ADataType>, f6x16_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, bf6x16_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, f6x32_pk_t> ||
|
||||
is_same_v<remove_cvref_t<ADataType>, bf6x32_pk_t>)&&GemmSpec !=
|
||||
GemmSpecialization::Default),
|
||||
"Packed F6 types do not support padding");
|
||||
|
||||
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding)
|
||||
|
||||
@@ -889,7 +889,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
const ScaleB& scale_b,
|
||||
FloatC& reg_c) const
|
||||
{
|
||||
|
||||
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
|
||||
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
|
||||
}
|
||||
@@ -1224,6 +1223,27 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
// Explicit specialization of GetMfma for the argument list
// <f6_t, 32, 32, f6_t, false, true>: selects the
// MfmaInstr::mfma_scale_f32_32x32x64f8f6f4 instruction.
// NOTE(review): the meaning of the two integer and two boolean template
// arguments is fixed by the primary GetMfma template, which is not visible
// in this hunk - confirm against it.
template <>
|
||||
constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
// Explicit specialization of GetMfma for the argument list
// <f6_t, 16, 16, f6_t, false, true>: selects the
// MfmaInstr::mfma_scale_f32_16x16x128f8f6f4 instruction.
// NOTE(review): the meaning of the two integer and two boolean template
// arguments is fixed by the primary GetMfma template, which is not visible
// in this hunk - confirm against it.
template <>
|
||||
constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
// Explicit specialization of GetMfma for the argument list
// <bf6_t, 32, 32, bf6_t, false, true>: selects the
// MfmaInstr::mfma_scale_f32_32x32x64f8f6f4 instruction, the same one the
// f6_t specialization with this 32/32 argument pair returns.
// NOTE(review): the meaning of the two integer and two boolean template
// arguments is fixed by the primary GetMfma template, which is not visible
// in this hunk - confirm against it.
template <>
|
||||
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
|
||||
}
|
||||
// Explicit specialization of GetMfma for the argument list
// <bf6_t, 16, 16, bf6_t, false, true>: selects the
// MfmaInstr::mfma_scale_f32_16x16x128f8f6f4 instruction, the same one the
// f6_t specialization with this 16/16 argument pair returns.
// NOTE(review): the meaning of the two integer and two boolean template
// arguments is fixed by the primary GetMfma template, which is not visible
// in this hunk - confirm against it.
template <>
|
||||
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
|
||||
{
|
||||
@@ -1405,8 +1425,7 @@ struct XdlopsGemm
|
||||
MPerXdlops == 64,
|
||||
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
|
||||
|
||||
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
|
||||
"KPack should be a multiple of k_per_blk");
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
|
||||
Reference in New Issue
Block a user