WMMA GEMM universal pipeline v1, mixed precision and paddings, examples (#2230)

* Fixed cmake errors related to gemm_bilinear. Previously, the cmake build failed when the following flags were set: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Updates to support mixed precision

* Added support for RRR F8xF16xF16 gemm_universal_wmma (WIP)

* Added support for F8xF16xF16 to gemm_wmma_universal

* Added support for F16xF8xF16 to gemm_wmma_universal

* Added support for BF16xI4xBF16 to gemm_wmma_universal

* Added support for F16xI4xF16 to gemm_wmma_universal

* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType
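
A minimal sketch of the corrected check (simplified stand-in types, not CK's
actual IsSupportedArgument):

#include <type_traits>

// Illustrative stand-ins for CK's storage/compute types.
struct half_t {};
struct f8_t {};

template <typename ADataType, typename BDataType,
          typename ComputeTypeA = ADataType, typename ComputeTypeB = BDataType>
constexpr bool is_supported_argument()
{
    // The WMMA instruction consumes ComputeTypeA/ComputeTypeB (inputs are
    // converted first), so hardware support must be checked against those
    // types rather than the storage types ADataType/BDataType.
    return std::is_same_v<ComputeTypeA, half_t> &&
           std::is_same_v<ComputeTypeB, half_t>;
}

// An f8 x f16 instance that up-converts A to half_t before the WMMA op is
// supported, even though a check on the storage type f8_t would reject it:
static_assert(is_supported_argument<f8_t, half_t, half_t, half_t>());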

* Added missing test class for FP16_KM_NK

* Pre-commit hooks fixes

* Added padding instances for f16xf16xf16

* Amended changes adding support for padding instances for f16xf16xf16

* Fixes for padding instances for f16xf16xf16

* Added padding instances for bf16xbf16, f8xf8

* Added packed instances for bf16xi4xbf16

* Added padding instances for f8xf16xf16

* Added padding instances for f16xf8xf16, f16xi4xf16

* Fixed typos for bf16xbf16xbf16 padding instances

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding is not supported when BDataType is pk_i4_t. Fixed the check accordingly and removed the padding instances.
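
A minimal sketch of the corrected check, with simplified names (pk_i4_t packs
two 4-bit values per byte, so padding K could split a packed pair):

#include <type_traits>

struct pk_i4_t { unsigned char packed; }; // two int4 lanes per byte

template <typename BDataType>
bool k_padding_supported(int KRaw, int KPerBlock)
{
    if constexpr(std::is_same_v<BDataType, pk_i4_t>)
        return KRaw % KPerBlock == 0; // no padding: K must already tile exactly
    else
        return true; // other B types may be padded to a multiple of KPerBlock
}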

* Fixed typos

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* Added an example for gemm_universal_bf16

* Added examples for gemm_wmma instances

* Added the missing parameters

* Addressed review comments and added the executable to CMakeLists

* Fixed clang-format issues

* Fixed build errors

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors were larger than needed (even though the compiler doesn't allocate registers for unused values).

* Load the M/NRepeat dims with the thread copy's slice instead of a loop
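
The shape of this change, in plain C++ rather than CK's thread-copy API
(illustrative only):

#include <cstring>

void copy_per_repeat(const float* src, float* dst, int repeats, int elems)
{
    for(int r = 0; r < repeats; ++r) // old shape: one transfer per repeat
        std::memcpy(dst + r * elems, src + r * elems, elems * sizeof(float));
}

void copy_sliced(const float* src, float* dst, int repeats, int elems)
{
    // new shape: a single transfer whose slice covers the repeat dimension
    std::memcpy(dst, src, static_cast<std::size_t>(repeats) * elems * sizeof(float));
}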

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue
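
The general shape of such a workaround, as a hedged sketch (plain C++, not
CK's code; the scalar decode below is a minimal e4m3 illustration):

#include <cmath>
#include <cstdint>

float convert_f8_to_f32(uint8_t x)
{
    // Known-good scalar path (NaN encodings ignored for brevity).
    const float sign = (x & 0x80) ? -1.0f : 1.0f;
    const int   exp  = (x >> 3) & 0xF;
    const int   man  = x & 0x7;
    if(exp == 0)
        return sign * std::ldexp(static_cast<float>(man), -9); // subnormal
    return sign * std::ldexp(8.0f + static_cast<float>(man), exp - 10);
}

void convert_pk_f8_to_f32(uint16_t packed, float out[2])
{
    // Unpack and convert one lane at a time instead of relying on the packed
    // two-lane conversion that triggers the issue.
    out[0] = convert_f8_to_f32(static_cast<uint8_t>(packed & 0xFF));
    out[1] = convert_f8_to_f32(static_cast<uint8_t>(packed >> 8));
}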

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute's, so b_device_buf.ToDevice reads out of bounds.
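
A sketch of the fix's shape (names follow the commit message; code is
illustrative): clamp the host-to-device copy to the bytes the host tensor
actually owns.

#include <algorithm>
#include <cstring>
#include <vector>

void to_device_clamped(void* device_dst, std::size_t device_bytes,
                       const std::vector<unsigned char>& host_src)
{
    // For pk_i4_t the device buffer is sized for more bytes than the packed
    // host tensor holds; copying device_bytes would read past host_src.
    const std::size_t n = std::min(device_bytes, host_src.size());
    std::memcpy(device_dst, host_src.data(), n); // stand-in for the H2D copy
}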

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds.

* Fix a bug where fp16_i4 validation passes only with PermuteB

The permutation required by the conversion from pk_i4_t to half_t does not
depend on PermuteB; the two can be used independently.
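
An illustrative sketch of why the two are independent: unpacking a pk_i4_t
byte yields its two 4-bit lanes in an order fixed by the conversion itself,
regardless of how B was laid out.

#include <cstdint>
#include <utility>

std::pair<int, int> unpack_pk_i4(uint8_t b)
{
    int lo = b & 0x0F;        // low nibble
    int hi = (b >> 4) & 0x0F; // high nibble
    // sign-extend 4-bit two's complement
    if(lo > 7) lo -= 16;
    if(hi > 7) hi -= 16;
    return {lo, hi}; // fixed lane order, with or without PermuteB
}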

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.
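
The gating pattern, as it appears in the diffs below (how CK_USE_WMMA_FP8 is
derived from the target list is assumed here):

#if defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8)
// f8_f8, f8_f16 and f16_f8 instance registrations live here; on gfx11 the
// macro stays undefined because the architecture lacks fast f8->f32
// conversion, even though the mixed f8/f16 cases need no f8 WMMA instructions.
#endif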

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to the common .inc file for better coverage of WMMA instances.

* Fix int8 DTYPES check for gemm_bilinear

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: Apoorva Kalyani <apoorva@streamhpc.com>

[ROCm/composable_kernel commit: 52b4860a30]
Author: Anton Gorenko
Date: 2025-06-04 12:22:33 +06:00
Committed by: GitHub
Parent: c395db8926
Commit: 780cb29a42
117 changed files with 4953 additions and 271 deletions


@@ -64,21 +64,45 @@ struct DeviceOperationInstanceFactory<
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
}
#endif
@@ -91,28 +115,52 @@ struct DeviceOperationInstanceFactory<
{
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
@@ -120,11 +168,144 @@ struct DeviceOperationInstanceFactory<
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_bf16_i4_bf16_mk_nk_mn_comp_default_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_bf16_i4_bf16_km_nk_mn_comp_default_instances(
op_ptrs);
}
}
#endif
#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_mnkpadding_instances(
op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_default_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_kpadding_instances(op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_mnpadding_instances(
op_ptrs);
add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_mnkpadding_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_i4_f16_mk_nk_mn_comp_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_wmma_universal_f16_i4_f16_km_nk_mn_comp_default_instances(op_ptrs);
}
}
#endif
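
For context, a sketch of how these registrations are typically consumed (the
CK profiler/client pattern; the header path and exact aliases are assumptions
here):

#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/gpu/gemm_universal.hpp" // path may differ

using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// f16_f8_f16, mk_nk_mn, as dispatched in the branches above.
using DeviceOp = ck::tensor_operation::device::DeviceGemmV2<
    Row, Col, Row,                    // ALayout, BLayout, CLayout
    ck::half_t, ck::f8_t, ck::half_t, // ADataType, BDataType, CDataType
    PassThrough, PassThrough, PassThrough>;

// With the padding registrations added above, this list now also contains
// kpadding/mnpadding/mnkpadding instances, so problem sizes that do not tile
// evenly no longer end up with zero supported instances.
auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();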


@@ -13,55 +13,355 @@ void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8))
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_i4_bf16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_i4_bf16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_i4_bf16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_i4_bf16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_bf16_i4_bf16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, BF16, I4, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(CK_USE_WMMA_FP8))
void add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_kpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_mnpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_mnkpadding_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_i4_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Row, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_wmma_universal_f16_i4_f16_km_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<
DeviceGemmV2<Col, Col, Row, F16, I4, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation