Wmma support for multiple Ds based GEMMs (#2613)

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"

* Fixed cmake build errors related to test_fp8

* Updates to support mixed precision


(cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip


(cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F8xF16xF16 to gemm_wmma_universal


(cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xF8xF16 to gemm_wmma_universal

* Added support for BF16xI4xBF16 to gemm_wmma_universal


(cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Added support for F16xI4xF16 to gemm_wmma_universal

* Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType

* Added missing test class for FP16_KM_NK

* Pre-commit hooks fixes

* Added padding instances for f16xf16xf16

* Fixed cmake errors related to  gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8"


(cherry picked from commit 5bdc993dbf)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Fixed cmake build errors related to test_fp8


(cherry picked from commit 12176616b6)

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

* Ammending changes for adding support for padding instances for f16xf16xf16

* Fixes for padding instances for f16xf16xf16

* Added padding instances for bf16xbf16, f8xf8

* Added packed instances for bf16xi4xbf16

* Added padding instances for f8xf16xf16

* Added padding instances for f16xf8xf16, f16xi4xf16

* Fixed typos for bf16xbf16xbf16 padding instances

* Fixed typos for padded instances

* Added tests for fp16, KM_KN and KM_NK

* Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances.

* Fixed typos

* Updated the set of tests for FP16

* Updated the set of tests for FP16

* Fix typo

* Moved f16xi4 test under the correct data layout group

* example for gemm_universal_bf16

* Adding examples for gemm_wmma instances

* Added the  missing parameters

* Fixed review comments and added executable to cmakeLists

* Fixing clang format

* Fixing build erros

* Fixed compilation failure.

* Modified some code as per gemm_universal_examples

* Fixed the gemm specialization error

* Fixed the build errors.

* Fix strides of a/b_thread_desc

The descriptors are larger than needed (even though the compiler don't alloc registers for unused values).

* Load in M/NRepeat dims with thread copy's slice instead of a loop

* Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation

* Implement Intrawave and Interwave variants of pipeline v1

* Add instances for Interwave and Intrawave v1

* Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0

* Remove instances that are too slow (mostly because of register spilling)

* Add a workaround for fp8/bf8->f32 packed conversion issue

* Add instances for Interwave and Intrawave v1

* Enable profiling of mixed precision with f8 and int4 on WMMA

* Fix segfault in profiler when B is pk_i4_t

b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds.

* Remove instances that are too slow (mostly because of register spilling)

* Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations

* Add test case for bf16_i4

* Add missing Regular tests

* Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS

They take more than 30 seconds

* Fix a bug that fp16_i4 validation passes only with PermuteB

A permutation required by conversion from pk_i4_t to half_t does not
depend on PermuteB, they can be used independently.

* Use PermuteB with f16_i4 in most instances (as xdl)

Some instances use PermuteB = false for checking correctness.
See also the previous commit.

* Fix cache flushing for pk_i4

* Add mixed precision examples

* Disable all tests and instances with f8 on gfx11

Even though f8_f16 and f16_f8 don't require f8 WMMA instructions,
gfx11 still lacks hardware instructions for fast f8->f32 conversion.

* Add FP16 KM_NK and KM_KN test suites for XDL

These tests were added to common .inc for better testing of WMMA instances

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* removed unnecessary ck parts from compilation

* initial gemm_add_multiply instance implementations

* fixed profiler help message for gemm_add_multiply

* improved multiply_add profiler layout help

* fixed template arguments for test instances

* added test for gemm_add_multiply

* Support multiple D in GridwiseGemm_wmma_cshuffle_v3

DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters.

* Use ThreadGroupTensorSliceTransfer_v7r3

* Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support

* Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma

* Implement DeviceGemmMultipleD_Wmma_CShuffleV3

* Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3

* Prepare gemma_add tests for adding wmma

* Add gemm_add_fastgelu instances and test

* Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API

ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use
DeviceGemmMultipleDSplitK instances there.

* switched to splitK interface

* log print added to splitk benchmarks

* revert main cmake comments

* newline change reverted

* added add_fastgelu instances

* revert unintended change in xdl add_fastgelu

* created gemm_add_add_fastgelu instances

* created fastegelu instances

* added tests for all splitk fastgelus

* Added tests.

* multiply_add instances created

* updates to add_multiply splitk instances

* splitk xdl test fixes

* added wmma multiply_multiply instances

* fixed ONLY_XDL_AND_WMMA_KERNELS tag

* Added gemm_add examples for wmma v1 and v3

* fixed / workarounded i8 instances

* Modified the v3 code to added one fp16 bxdl instance.

* added bf16 xdl instance.

* adding gemm_add wmma_cshuffle and other support


(cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* add instances into camkelists


(cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* This is work in progress, edited the template parameters in order to build

(cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype


(cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* added datatype and use clang-format-12


(cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec)

Co-authored-by: Cenxuan <cenxuan@streamhpc.com>

* Fixing build errors

* Added instances for v3

* Adding instances and executables

* Code update of template parameters modified.

* Renamed file.

* Added tests.

* resolved error tests.

* Fixing build errors

* Updated comments

* removed the changes as per the MR review comment.

* Updated tests.

* fp8 instances - not tested

* Restored the Cmake file that was reverted by mistake during rebase.

* fixed wmma_op test

* Updated comments.

* Updated the template parameter description

* fixed rdna4 instances

* fixed back compatibility on gfx11

* cleanups

* fix ckProfiler

* one more cmake fix

* added fp8 instances

* Updated tests to ad BF16 instances as per review comment

* Added include file and cleaned up(as per review comment)

* Updated and optimized the example code for all types.

* Fixed clang format

* Resolve "Implement `device_gemm_bilinear` for RDNA4"

* test generalization to handle FP16 shuffle better

* added missing changes

* Added bf16 wmma instance for add_relu

* Added f16 wmma instance and corrected bf16 instance errors.

* Added instances to Cmake

* Modified the template parameters to make the instances work.

* Fixed typo in profiler

* Added v3 instances for gemm_add_relu

* addressed core review comments

* Added test for gemm_add_relu wmma instance

* Cleaned up the code.

* Added examples for gemm_add_relu

* Fixing typo to resolve build errors.

* Fixes applied to fix  the precision loss.

* fix billinear test after merge

* Removed the old wmma instances.

* Added wrapper and renamed the wmma_v3 instances

* Updated copyrights and added wrappers.

* Fixes applied according to review comments

* Apply 1 suggestion(s) to 1 file(s)

Co-authored-by: Robin Voetter <robin@streamhpc.com>

* Removed the old wmma instances.

* Updated wrapper for the v3 instances

* removed the old wmma examples

* Renamed the v3 instances

* Deleted the  gtest file added by mistake.

* Updated thge profiler with wrapper

* Fixed test errors.

* Fixed the review comments

* Fixed the if condition MACROS.

* REVERTED THE PROFILER CHANGES

* Revert "REVERTED THE PROFILER CHANGES"

This reverts commit 21cb98546c.

* Revert "Fixed test errors."

This reverts commit 13efcc6fe1.

* Revert "Updated thge profiler with wrapper"

This reverts commit 536f86661d.

* Added missing wrapper instances

* Updated copyrights.

* Fixed typo.

* Fixed copyrights.

* Updated copyrights.

* updated copyrights.

* comments on the atomics workaround

* fixed cmake comment

* Fix bug from merge

* clang-format-18

* Fix compilation error

* Fix linking error

* Fix bug in add and add_relu examples

* Fix error including file (typo)

* Quick fix to compile examples for different targets

* Fix for multi target

* implemented f16 and bf16 instances for gemm_silu

* addressed review comments

* addressed review comments

* Fix clang format

* Fix clang format

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
Co-authored-by: apoorva <apoorva@streamhpc.com>
Co-authored-by: Anton Gorenko <anton@streamhpc.com>
Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com>
Co-authored-by: Cenxuan <cenxuan@streamhpc.com>
Co-authored-by: Robin Voetter <robin@streamhpc.com>
Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: b740380906]
This commit is contained in:
Enrico Degregori
2025-09-05 16:31:08 +02:00
committed by GitHub
parent 40361182ca
commit 74a885c5c0
112 changed files with 7877 additions and 715 deletions

View File

@@ -75,7 +75,7 @@ function(add_instance_library INSTANCE_NAME)
endif()
# Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
if(NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_")
message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -117,13 +117,13 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950
#only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 and gfx1200/gfx1201
if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)
if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
endif()
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
@@ -136,7 +136,7 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
endif()
if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx10-3-generic gfx11-generic)
endif()
if(source_name MATCHES "gemm_universal_preshuffle" AND source_name MATCHES "f8")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
@@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx11|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
set(add_inst 0)
endif()

View File

@@ -1,5 +1,8 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_instance
device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, Add, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Add>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,10 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_add_fastgelu_instance
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// elementwise(c, d0, d1) = fastgelu(c + d0 + d1)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddAddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddAddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,9 +1,14 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_fastgelu_instance
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,77 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0)
// elementwise(c, d0) = fastgelu(c + d0)
// output: e[m, n]
// input: a[m, k], b[n, k], d0[m, n]
template <GemmSpecialization GemmSpec>
using device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddFastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddFastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,11 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_multiply_instance
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,78 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
using device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, AddMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
AddMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,8 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_relu_instance
device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instance.cpp
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, BF16, BF16, BF16_Tuple, BF16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,70 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
template <GemmSpecialization GemmSpec>
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
using device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F32, PassThrough, PassThrough, AddRelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddRelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,6 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_add_silu_instance
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,109 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise((a * b), d0, d1)
// outout: e[m, n]
// input: a[m, k], b[k, n], d0[m, n], d1[m, n]
template <GemmSpecialization GemmSpec, typename T_dataType, typename T_tupleDataType>
using device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, T_dataType, T_dataType, T_tupleDataType, T_dataType, F32, F32, PassThrough, PassThrough, AddSilu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
AddSilu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault,
F16,
F16_Tuple>{});
add_device_operation_instances(
instances,
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding,
F16,
F16_Tuple>{});
}
// implement bf16 with same parameters to avoid duplication
void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
AddSilu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault,
BF16,
BF16_Tuple>{});
add_device_operation_instances(
instances,
device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmMNKPadding,
BF16,
BF16_Tuple>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -4,6 +4,10 @@ add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp

View File

@@ -0,0 +1,72 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
template <GemmSpecialization GemmSpec>
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
template <GemmSpecialization GemmSpec>
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,77 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
template <GemmSpecialization GemmSpec>
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e[m, n] = bilinear(a[k, m] * b[k, n], d[m, n])
template <GemmSpecialization GemmSpec>
using device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16, F32, F16, PassThrough, PassThrough, Bilinear, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Tuple,
Row,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances<
GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -45,7 +45,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
// M/N/K padding
// N % 16 == 0 && K % 16 == 0
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
@@ -55,7 +55,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
// M/N/K padding
// N % 8 == 0 && K % 8 == 0
//################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
@@ -65,7 +65,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
// M/N/K padding
// N % 8 == 0 && K % 8 == 0
@@ -76,7 +75,6 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>,
// M/N/K padding
// N % 1 == 0 && K % 8 == 0
@@ -86,8 +84,7 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>,
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1>
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>
// clang-format on
>;

View File

@@ -1,5 +1,10 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_fastgelu_instance
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp

View File

@@ -0,0 +1,75 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instance<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,77 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,80 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,83 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
// e = elementwise(a * b)
// elementwise(c) = fastgelu(c)
// output: e[m, n]
// input: a[m, k], b[n, k]
template <GemmSpecialization GemmSpec>
using device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16, F32, F32, PassThrough, PassThrough, FastGelu, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
FastGelu>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,7 +1,11 @@
# ONLY XDL_KERNELS
set(GEMM_MULTIPLY_ADD_INSTANCES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_gemm_multiply_add_instance
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp
)

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
using device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F16, F16_F16_Tuple, F16, F32, F16, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F16,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
using device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V3>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Row_Tuple, Row, F16, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyAdd, GemmMNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3>
// clang-format on
>;
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Row_Tuple,
Row,
F16,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyAdd>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,4 +1,4 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
@@ -38,6 +38,11 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_mk_nk_mn.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_mk_nk_mn.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
F8,
F8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_f8_f8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, F8, F8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, F8, F8, F32_F32_Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, F8, F8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
F8,
F8,
F32_F32_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_f8_f8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,71 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| Compute| Compute|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| TypeA| TypeB|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<1, 1, 1>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<1, 1, 1>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck