mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Add gemm universal bf16 instances (#1484)
* revert ckprofiler change * temp save * Add test and test pass * test pass * Fix bug inside rotating buffer when tensor is not packed * bug fix * clang format --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
@@ -335,6 +335,105 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_insta
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
|
||||
@@ -618,6 +717,58 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_mkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v1_mkpadding_instances(
|
||||
op_ptrs);
|
||||
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_mem_v2_mkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8))
|
||||
|
||||
Reference in New Issue
Block a user