mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Merge commit 'b7403809062654e6a0e54abc2623d3f8f9fd0288' into develop
This commit is contained in:
@@ -75,7 +75,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endif()
|
||||
# Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
|
||||
message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -117,13 +117,13 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
|
||||
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950
|
||||
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201
|
||||
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
|
||||
if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
@@ -136,7 +136,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
|
||||
endif()
|
||||
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
@@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
message(DEBUG "Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
|
||||
message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_instance
|
||||
device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
|
||||
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// outout: e[m, n]
|
||||
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
|
||||
using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// outout: e[m, n]
|
||||
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
|
||||
using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,10 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_add_fastgelu_instance
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,9 +1,14 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_fastgelu_instance
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0)
|
||||
// elementwise(c, d0) = fastgelu(c + d0)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,79 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0)
|
||||
// elementwise(c, d0) = fastgelu(c + d0)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,82 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0)
|
||||
// elementwise(c, d0) = fastgelu(c + d0)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,85 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0)
|
||||
// elementwise(c, d0) = fastgelu(c + d0)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k], d0[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,11 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_multiply_instance
|
||||
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,78 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,8 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_relu_instance
|
||||
device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// outout: e[m, n]
|
||||
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
|
||||
using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
|
||||
GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// outout: e[m, n]
|
||||
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
|
||||
using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,6 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_add_silu_instance
|
||||
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise((a * b), d0, d1)
|
||||
// outout: e[m, n]
|
||||
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
|
||||
|
||||
template <GemmSpecialization GemmSpec, typename T_dataType, typename T_tupleDataType>
|
||||
using device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault,
|
||||
F16,
|
||||
F16_Tuple>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding,
|
||||
F16,
|
||||
F16_Tuple>{});
|
||||
}
|
||||
|
||||
// implement bf16 with same parameters to avoid duplication
|
||||
void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault,
|
||||
BF16,
|
||||
BF16_Tuple>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding,
|
||||
BF16,
|
||||
BF16_Tuple>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -4,6 +4,10 @@ add_instance_library(device_gemm_bilinear_instance
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
|
||||
GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -45,7 +45,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 16 == 0 && K % 16 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
@@ -55,7 +55,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 8 == 0 && K % 8 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
@@ -65,7 +65,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 8 == 0 && K % 8 == 0
|
||||
@@ -76,7 +75,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 1 == 0 && K % 8 == 0
|
||||
@@ -86,8 +84,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1>
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_fastgelu_instance
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise(a * b)
|
||||
// elementwise(c) = fastgelu(c)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,77 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise(a * b)
|
||||
// elementwise(c) = fastgelu(c)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise(a * b)
|
||||
// elementwise(c) = fastgelu(c)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,83 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
// e = elementwise(a * b)
|
||||
// elementwise(c) = fastgelu(c)
|
||||
// output: e[m, n]
|
||||
// input: a[m, k], b[n, k]
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,7 +1,11 @@
|
||||
# ONLY XDL_KERNELS
|
||||
set(GEMM_MULTIPLY_ADD_INSTANCES)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
add_instance_library(device_gemm_multiply_add_instance
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
|
||||
using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,4 +1,4 @@
|
||||
# ONLY XDL_KERNELS
|
||||
# ONLY XDL_AND_WMMA_KERNELS
|
||||
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
|
||||
|
||||
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
|
||||
@@ -38,6 +38,11 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
|
||||
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp
|
||||
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
|
||||
|
||||
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp
|
||||
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp
|
||||
device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp
|
||||
device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp
|
||||
)
|
||||
|
||||
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
F32_F32_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,71 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
|
||||
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
|
||||
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
|
||||
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmDefault>{});
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user