mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add splitk gemm fp16 @ fp16 with fp8 compute instances (#983)
* Add ComputeType * Update for compatibility * Add instances * Update profiler api
This commit is contained in:
@@ -98,6 +98,26 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough, F8>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
@@ -105,7 +125,8 @@ template <typename ADataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
typename CLayout,
|
||||
typename ComputeType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -115,7 +136,8 @@ struct DeviceOperationInstanceFactory<
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeType>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmSplitK<ALayout,
|
||||
BLayout,
|
||||
@@ -125,7 +147,8 @@ struct DeviceOperationInstanceFactory<
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ComputeType>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
@@ -158,7 +181,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
is_same_v<CDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
@@ -231,6 +254,31 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t> && is_same_v<ComputeType, f8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user