mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
MX GEMM - Add FP8 GEMM Tests for Different Layouts (#2152)
* Add gemm_mx_fp8_bf8 example with row-major B * Add more overloads of MX MFMA instructions * Add MK_KN (RRR) tests * Add KM_NK (CCR) tests * Add more problem sizes to Large tests * Add test_gemm_mx to the list of regression tests
This commit is contained in:
committed by
GitHub
parent
b9d17bdb11
commit
79b0bfeb41
@@ -45,6 +45,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Row,
|
||||
Row,
|
||||
BF8,
|
||||
e8m0_bexp_t,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Col,
|
||||
Col,
|
||||
Row,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
F8,
|
||||
e8m0_bexp_t,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
@@ -93,11 +121,31 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, BF8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user