mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
MX GEMM - Add MX BF8 example (#2071)
* Add MX GEMM example for MX BF8
* Verified MX FP8 with 16x16x128 scale builtin
* Verify MX BF8 GEMM with BF16 output
[ROCm/composable_kernel commit: da54464cce]
This commit is contained in:
committed by
GitHub
parent
bf57a9a5f3
commit
348760d56e
@@ -699,6 +699,9 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
|
||||
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
|
||||
|
||||
static_assert(is_same_v<ComputeTypeA, ADataType> && is_same_v<ComputeTypeB, BDataType>,
|
||||
"ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -1141,6 +1141,12 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
|
||||
{
|
||||
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
|
||||
@@ -588,6 +588,35 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bf8x32_t& reg_a,
|
||||
const int32_t& scale_a,
|
||||
const bf8x32_t& reg_b,
|
||||
const int32_t& scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
reg_a,
|
||||
reg_b,
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
1, // cbsz
|
||||
1, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user