mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Add fp8 gemm instances (#920)
* Add fp8 gemm instances
* Update instance naming
[ROCm/composable_kernel commit: 94bfa50256]
This commit is contained in:
@@ -312,6 +312,23 @@ void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
#endif
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -505,6 +522,32 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
else if constexpr(is_same_v<ADataType, ck::f8_t> && is_same_v<BDataType, ck::f8_t> &&
|
||||
is_same_v<CDataType, ck::f8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user