MX GEMM - Add MX BF8 example (#2071)

* Add MX GEMM example for MX BF8

* Verified MX FP8 with 16x16x128 scale builtin

* Verify MX BF8 GEMM with BF16 output


[ROCm/composable_kernel commit: da54464cce]
This commit is contained in:
Andriy Roshchenko
2025-04-16 15:25:02 -06:00
committed by GitHub
parent bf57a9a5f3
commit 348760d56e
5 changed files with 139 additions and 0 deletions

View File

@@ -699,6 +699,9 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
static_assert(is_same_v<ComputeTypeA, ADataType> && is_same_v<ComputeTypeB, BDataType>,
"ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType");
return true;
}

View File

@@ -1141,6 +1141,12 @@ struct MfmaSelector
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<bf8_t, 32, 32>()
{

View File

@@ -588,6 +588,35 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const bf8x32_t& reg_a,
const int32_t& scale_a,
const bf8x32_t& reg_b,
const int32_t& scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz
1, // blgp
0, // OPSEL
scale_a,
0, // OPSEL
scale_b);
#else
ignore = reg_a;
ignore = scale_a;
ignore = reg_b;
ignore = scale_b;
ignore = reg_c;
#endif
}
};