mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Split the instances by architecture. (#1223)
* parse examples inside the add_example_executable function * fix the example 64 cmake file * add xdl flag to the gemm_bias_softmax_gemm_permute example * add filtering of tests based on architecture type * enable test_grouped_gemm for gfx9 only * enable test_transpose only for gfx9 * only linnk test_transpose if it gets built * split the gemm instances by architectures * split gemm_bilinear,grouped_conv_bwd_weight instances by targets * split instances by architecture * split grouped_conv instances by architecture * fix clang format * fix the if-else logic in group_conv headers * small fix for grouped convolution instances * fix the grouped conv bwd weight dl instances * fix client examples * only enable client examples 3 and 4 on gfx9 * set the gfx9 macro * make sure the architecture macros are set by cmake * use separate set of xdl/wmma flags for host code * sinmplify the main cmake file * add conv_fwd_bf8 instance declaration
This commit is contained in:
@@ -12,398 +12,21 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
#include "gemm_dl.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "gemm_wmma.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_XDL
|
||||
#include "gemm_xdl.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
|
||||
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
@@ -435,16 +58,137 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, int8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // DL_KERNELS
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -452,10 +196,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -463,10 +203,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -474,10 +210,6 @@ struct DeviceOperationInstanceFactory<
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
@@ -490,57 +222,25 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
/// add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -578,37 +278,21 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(op_ptrs);
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
@@ -658,6 +342,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
@@ -69,7 +69,7 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -159,7 +159,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -189,7 +189,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
|
||||
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_FP32)
|
||||
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,238 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
|
||||
DeviceGemm<Col, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F64, F64, F64, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_default_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v1_interwave_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_kn_mn_v2_padded_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F8, F8, F8, PassThrough, PassThrough, PassThrough>>>& instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F8, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,439 +10,18 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#include "grouped_convolution_backward_data_xdl.inc"
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "grouped_convolution_backward_data_wmma.inc"
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
@@ -488,9 +67,10 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
@@ -500,43 +80,28 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
@@ -544,45 +109,29 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
@@ -593,35 +142,144 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
|
||||
is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
op_ptrs);
|
||||
@@ -638,46 +296,16 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
else if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, bf8_t> &&
|
||||
is_same_v<ComputeTypeB, f8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32> && is_same_v<ComputeTypeA, F32> &&
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> && is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
op_ptrs);
|
||||
@@ -687,6 +315,7 @@ struct DeviceOperationInstanceFactory<
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,216 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,243 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
GKXC,
|
||||
NWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,114 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,228 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,480 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,357 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv1d forward, GNWC/GKXC/GNWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<1,
|
||||
GNWC,
|
||||
GKXC,
|
||||
Empty_Tuple,
|
||||
GNWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F8,
|
||||
F8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGK,
|
||||
BF8,
|
||||
BF8,
|
||||
Empty_Tuple,
|
||||
F8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
BF8>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user