Merge commit '054f85ab7c0fa07a90968e834899ec415af8b713' into develop

This commit is contained in:
assistant-librarian[bot]
2025-07-07 17:07:08 +00:00
parent 7a78fb644d
commit f8ee69963d
18 changed files with 578 additions and 95 deletions

View File

@@ -66,9 +66,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
// store rows/cols into thread registers in chunks of 16 for FP8
// e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
// or in chunks of 32 / APackedSize for FP6/FP4
static constexpr index_t KThreadChunk = (APackedSize == 1) ? 16 : 32 / APackedSize;
static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now");
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;

View File

@@ -54,6 +54,8 @@ namespace device {
*
* Conditions for achieving computational load balancing on different hardware platforms can vary.
*
* \tparam KPerBlock is the number of elements in K dimension that each block processes (multiply with packed_size_v to get the actual KPerBlock)
*
* Serialized version of the algorithm:
* \code
* // E = A * B + C
@@ -117,7 +119,7 @@ template <typename ALayout,
index_t BlockSize, // Thread block size
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t KPerBlock, // multiply with packed_size_v to get the actual KPerBlock
index_t AK1,
index_t BK1,
index_t MPerXDL,

View File

@@ -419,6 +419,12 @@ struct GridwiseGemmMX_xdl_cshuffle_v3
(GemmSpec != GemmSpecialization::Default &&
GemmSpec != GemmSpecialization::MPadding)),
"f4x2_pk_t does not support K padding");
static_assert(!((is_same_v<remove_cvref_t<ADataType>, f6x16_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, bf6x16_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, f6x32_pk_t> ||
is_same_v<remove_cvref_t<ADataType>, bf6x32_pk_t>)&&GemmSpec !=
GemmSpecialization::Default),
"Packed F6 types do not support padding");
if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding)

View File

@@ -889,7 +889,6 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
@@ -1224,6 +1223,27 @@ struct MfmaSelector
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
{
@@ -1405,8 +1425,7 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B