MX GEMM - New GEMM pipeline for MX data types (#2059)

* Allow selection of mfma_scale instructions

* Read B tensor from LDS to VGPR in chunks of 16 in MFMA order

* Add constexpr and synchronize return type for `get_exponent_value`

* Pass scales by reference and add comments to `mfma_scale_f32_32x32x64`

* Add support for microscaling instructions in `XdlopsGemm`

* Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper

* Remove software implementation of MX GEMM

* Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction

* Update README

* Updated CHANGELOG

* Remove unused static methods
This commit is contained in:
Andriy Roshchenko
2025-04-15 17:17:07 -06:00
committed by GitHub
parent d55c9cb313
commit 7106976a72
19 changed files with 1007 additions and 608 deletions

View File

@@ -694,14 +694,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
static_assert((is_same_v<ADataType, f8_t> || is_same_v<ADataType, bf8_t> ||
is_same_v<ADataType, f6_t> || is_same_v<ADataType, bf6_t> ||
is_same_v<ADataType, f4_t>)&&(is_same_v<BDataType, f8_t> ||
is_same_v<BDataType, bf8_t> ||
is_same_v<BDataType, f6_t> ||
is_same_v<BDataType, bf6_t> ||
is_same_v<BDataType, f4_t>),
static_assert(is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>(),
"Only microscaling formats are supported for ADataType and BDataType");
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
@@ -711,6 +704,11 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(!IsValidCompilationParameter())
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;