added wmma multiply_multiply instances

This commit is contained in:
Zoltan Lakatos
2025-06-17 15:03:18 +00:00
parent 881bc3fbb0
commit ac60286ed0
13 changed files with 360 additions and 60 deletions

View File

@@ -279,10 +279,10 @@ FOREACH(subdir_path ${dir_list})
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
set(add_inst 0)
endif()
# if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
# message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
# set(add_inst 0)
# endif()
if ("${cmake_instance}" MATCHES "gemm_bilinear")
set(add_inst 0)
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9") AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES))

View File

@@ -1,4 +1,4 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
@@ -38,6 +38,9 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck