mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
universal streamk fp8 changes (#1665)
* universal streamk fp8 changes & ckprofiler instances
* revert strides to -1 and verification options
* fp8 exclusion on pre-gfx94 for universal_streamk
* PR review based revisions: permissions reverted, removed hip err checks
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
[ROCm/composable_kernel commit: d6d4c2788b]
This commit is contained in:
committed by
GitHub
parent
326639e80c
commit
5a5bfe14f4
@@ -237,6 +237,206 @@ void add_device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_mem_v2_mnkpaddin
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm_Streamk_V2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
@@ -327,6 +527,121 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemm_S
|
||||
}
|
||||
#endif
|
||||
|
||||
#if(defined(CK_ENABLE_FP8))
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user