MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)

* Add gemm_mx_fp8_bf8 example with row-major B

* Add more overloads of MX MFMA instructions

* Add MK_KN (RRR) tests

* Add KM_NK (CCR) tests

* Add more problem sizes to Large tests

* Add test_gemm_mx to the list of regression tests
This commit is contained in:
Andriy Roshchenko
2025-05-01 11:55:48 -06:00
committed by GitHub
parent b9d17bdb11
commit 79b0bfeb41
15 changed files with 642 additions and 18 deletions

View File

@@ -45,6 +45,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Row,
Row,
BF8,
e8m0_bexp_t,
F8,
e8m0_bexp_t,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Col,
Col,
Row,
F8,
e8m0_bexp_t,
F8,
e8m0_bexp_t,
BF16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ADataType,
typename AScaleDataType,
typename BDataType,
@@ -93,11 +121,31 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
is_same_v<CDataType, BF16>)
{
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
if constexpr(is_same_v<ADataType, BF8> && is_same_v<BDataType, F8> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
is_same_v<CDataType, BF16>)
{
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(op_ptrs);
}
}