mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 08:00:13 +00:00
MX GEMM - Add FP6 GEMM Test (#2488)
* Add F6 GEMM MX Test * Add BF6 GEMM MX Test
This commit is contained in:
committed by
GitHub
parent
518dc21ae8
commit
25b359d630
@@ -24,6 +24,8 @@ using F8 = ck::f8_t;
|
||||
using BF8 = ck::bf8_t;
|
||||
using I4 = ck::pk_i4_t;
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F6 = ck::f6x16_pk_t;
|
||||
using BF6 = ck::bf6x16_pk_t;
|
||||
|
||||
using E8M0 = ck::e8m0_bexp_t;
|
||||
using E8M0PK = int32_t;
|
||||
|
||||
@@ -87,6 +87,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F6,
|
||||
E8M0PK,
|
||||
F16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMX<Row,
|
||||
Col,
|
||||
Row,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF6,
|
||||
E8M0PK,
|
||||
BF16,
|
||||
32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
template <typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
@@ -130,6 +158,8 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
|
||||
{
|
||||
// Row-Col-Row -- one of the two currently supported layouts, another one is
|
||||
// Row-MFMA-Row
|
||||
if constexpr(is_same_v<ADataType, F8> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
@@ -147,6 +177,16 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, F6> && is_same_v<BDataType, F6> &&
|
||||
is_same_v<CDataType, F16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, BF6> && is_same_v<BDataType, BF6> &&
|
||||
is_same_v<CDataType, BF16>)
|
||||
{
|
||||
add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
|
||||
Reference in New Issue
Block a user