mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 14:54:47 +00:00
Wmma support for multiple Ds based GEMMs (#2613)
* Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" * Fixed cmake build errors related to test_fp8 * Updates to support mixed precision (cherry picked from commit e65d71180393e7b66169c56565a6bac740427de6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Adding support for RRR, F8xF16xF16 gemm_universal_wmma - wip (cherry picked from commit f8c06322df0abcbd5945a56cdf5bffe56480f9f0) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F8xF16xF16 to gemm_wmma_universal (cherry picked from commit 15c851de6daa513a12c2e3af299bab0176175fb5) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xF8xF16 to gemm_wmma_universal * Added support for BF16xI4xBF16 to gemm_wmma_universal (cherry picked from commit c6a4a69d2d43d59bae8bdabfae80d648646f217e) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Added support for F16xI4xF16 to gemm_wmma_universal * Fixed IsSupportedArgument to check ComputeTypeA, ComputeTypeB instead of ADataType, BDataType * Added missing test class for FP16_KM_NK * Pre-commit hooks fixes * Added padding instances for f16xf16xf16 * Fixed cmake errors related to gemm_bilinear. Previously, if the above flags are set, cmake build fails: GPU_TARGETS="gfx1100;gfx1201" -D DTYPES="fp16;bf16;fp8" (cherry picked from commit5bdc993dbf) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Fixed cmake build errors related to test_fp8 (cherry picked from commit12176616b6) Co-authored-by: Anca Hamuraru <anca@streamhpc.com> * Ammending changes for adding support for padding instances for f16xf16xf16 * Fixes for padding instances for f16xf16xf16 * Added padding instances for bf16xbf16, f8xf8 * Added packed instances for bf16xi4xbf16 * Added padding instances for f8xf16xf16 * Added padding instances for f16xf8xf16, f16xi4xf16 * Fixed typos for bf16xbf16xbf16 padding instances * Fixed typos for padded instances * Added tests for fp16, KM_KN and KM_NK * Padding not supported for when BDataType is pk_i4_t. Added fix for correct check and removed padding instances. * Fixed typos * Updated the set of tests for FP16 * Updated the set of tests for FP16 * Fix typo * Moved f16xi4 test under the correct data layout group * example for gemm_universal_bf16 * Adding examples for gemm_wmma instances * Added the missing parameters * Fixed review comments and added executable to cmakeLists * Fixing clang format * Fixing build erros * Fixed compilation failure. * Modified some code as per gemm_universal_examples * Fixed the gemm specialization error * Fixed the build errors. * Fix strides of a/b_thread_desc The descriptors are larger than needed (even though the compiler don't alloc registers for unused values). * Load in M/NRepeat dims with thread copy's slice instead of a loop * Clone BlockwiseGemmXdlops_pipeline_v1 for WMMA implementation * Implement Intrawave and Interwave variants of pipeline v1 * Add instances for Interwave and Intrawave v1 * Add instances with ABlockLdsExtraM and BBlockLdsExtraN = 0 * Remove instances that are too slow (mostly because of register spilling) * Add a workaround for fp8/bf8->f32 packed conversion issue * Add instances for Interwave and Intrawave v1 * Enable profiling of mixed precision with f8 and int4 on WMMA * Fix segfault in profiler when B is pk_i4_t b_device_buf's size in bytes is larger than b_k_n_permute so b_device_buf.ToDevice reads out-of-bounds. * Remove instances that are too slow (mostly because of register spilling) * Add missing add_device_gemm_wmma_universal_f8_f8_bf16 declarations * Add test case for bf16_i4 * Add missing Regular tests * Add test_gemm_universal_xdl/wmma_fp16 to REGRESSION_TESTS They take more than 30 seconds * Fix a bug that fp16_i4 validation passes only with PermuteB A permutation required by conversion from pk_i4_t to half_t does not depend on PermuteB, they can be used independently. * Use PermuteB with f16_i4 in most instances (as xdl) Some instances use PermuteB = false for checking correctness. See also the previous commit. * Fix cache flushing for pk_i4 * Add mixed precision examples * Disable all tests and instances with f8 on gfx11 Even though f8_f16 and f16_f8 don't require f8 WMMA instructions, gfx11 still lacks hardware instructions for fast f8->f32 conversion. * Add FP16 KM_NK and KM_KN test suites for XDL These tests were added to common .inc for better testing of WMMA instances * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * removed unnecessary ck parts from compilation * initial gemm_add_multiply instance implementations * fixed profiler help message for gemm_add_multiply * improved multiply_add profiler layout help * fixed template arguments for test instances * added test for gemm_add_multiply * Support multiple D in GridwiseGemm_wmma_cshuffle_v3 DeviceGemm_Wmma_CShuffleV3 is changed for new template parameters. * Use ThreadGroupTensorSliceTransfer_v7r3 * Clone for device_gemm_wmma_cshuffle_v3.hpp for future Multiple D support * Clone example/65_gemm_multiply_multiply/gemm_add_add_xdl_fp16.cpp for wmma * Implement DeviceGemmMultipleD_Wmma_CShuffleV3 * Make gemm_add_add_wmma to work with DeviceGemmMultipleD_Wmma_CShuffleV3 * Prepare gemma_add tests for adding wmma * Add gemm_add_fastgelu instances and test * Add a special wrapper to use DeviceGemmMultipleD_Wmma_CShuffleV3 with old API ckProfiler uses DeviceGemmMultipleD (tests also call its functions), the wrapper allows to use DeviceGemmMultipleDSplitK instances there. * switched to splitK interface * log print added to splitk benchmarks * revert main cmake comments * newline change reverted * added add_fastgelu instances * revert unintended change in xdl add_fastgelu * created gemm_add_add_fastgelu instances * created fastegelu instances * added tests for all splitk fastgelus * Added tests. * multiply_add instances created * updates to add_multiply splitk instances * splitk xdl test fixes * added wmma multiply_multiply instances * fixed ONLY_XDL_AND_WMMA_KERNELS tag * Added gemm_add examples for wmma v1 and v3 * fixed / workarounded i8 instances * Modified the v3 code to added one fp16 bxdl instance. * added bf16 xdl instance. * adding gemm_add wmma_cshuffle and other support (cherry picked from commit ec447e7f564095ea969eddc39ec77b843aa52976) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * add instances into camkelists (cherry picked from commit 23bf2d2771c939ea3ca7f493433c55255bffd08e) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * This is work in progress, edited the template parameters in order to build (cherry picked from commit b4fde8a3314cb44659c4bbda35f1a0133c63dc41) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * temp work saved, changed the BDataType to f16 or bf16 since wmma currently not support non-equal A and B datatype (cherry picked from commit 22fbd68f1db458ab50780a394ee2544c7a1484d1) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * added datatype and use clang-format-12 (cherry picked from commit ae4e853682ef1bb27784b2f965b4a66b3751ceec) Co-authored-by: Cenxuan <cenxuan@streamhpc.com> * Fixing build errors * Added instances for v3 * Adding instances and executables * Code update of template parameters modified. * Renamed file. * Added tests. * resolved error tests. * Fixing build errors * Updated comments * removed the changes as per the MR review comment. * Updated tests. * fp8 instances - not tested * Restored the Cmake file that was reverted by mistake during rebase. * fixed wmma_op test * Updated comments. * Updated the template parameter description * fixed rdna4 instances * fixed back compatibility on gfx11 * cleanups * fix ckProfiler * one more cmake fix * added fp8 instances * Updated tests to ad BF16 instances as per review comment * Added include file and cleaned up(as per review comment) * Updated and optimized the example code for all types. * Fixed clang format * Resolve "Implement `device_gemm_bilinear` for RDNA4" * test generalization to handle FP16 shuffle better * added missing changes * Added bf16 wmma instance for add_relu * Added f16 wmma instance and corrected bf16 instance errors. * Added instances to Cmake * Modified the template parameters to make the instances work. * Fixed typo in profiler * Added v3 instances for gemm_add_relu * addressed core review comments * Added test for gemm_add_relu wmma instance * Cleaned up the code. * Added examples for gemm_add_relu * Fixing typo to resolve build errors. * Fixes applied to fix the precision loss. * fix billinear test after merge * Removed the old wmma instances. * Added wrapper and renamed the wmma_v3 instances * Updated copyrights and added wrappers. * Fixes applied according to review comments * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Robin Voetter <robin@streamhpc.com> * Removed the old wmma instances. * Updated wrapper for the v3 instances * removed the old wmma examples * Renamed the v3 instances * Deleted the gtest file added by mistake. * Updated thge profiler with wrapper * Fixed test errors. * Fixed the review comments * Fixed the if condition MACROS. * REVERTED THE PROFILER CHANGES * Revert "REVERTED THE PROFILER CHANGES" This reverts commit21cb98546c. * Revert "Fixed test errors." This reverts commit13efcc6fe1. * Revert "Updated thge profiler with wrapper" This reverts commit536f86661d. * Added missing wrapper instances * Updated copyrights. * Fixed typo. * Fixed copyrights. * Updated copyrights. * updated copyrights. * comments on the atomics workaround * fixed cmake comment * Fix bug from merge * clang-format-18 * Fix compilation error * Fix linking error * Fix bug in add and add_relu examples * Fix error including file (typo) * Quick fix to compile examples for different targets * Fix for multi target * implemented f16 and bf16 instances for gemm_silu * addressed review comments * addressed review comments * Fix clang format * Fix clang format --------- Co-authored-by: Anca Hamuraru <anca@streamhpc.com> Co-authored-by: apoorva <apoorva@streamhpc.com> Co-authored-by: Anton Gorenko <anton@streamhpc.com> Co-authored-by: Zoltan Lakatos <zoltan.lakatos@streamhpc.com> Co-authored-by: Cenxuan <cenxuan@streamhpc.com> Co-authored-by: Robin Voetter <robin@streamhpc.com> Co-authored-by: Kiefer van Teutem <kiefer.van.teutem@streamhpc.com> Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit:b740380906]
This commit is contained in:
@@ -54,6 +54,7 @@ using MFMA = ck::tensor_layout::gemm::MFMA;
|
||||
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
using Row_Row_Tuple = ck::Tuple<Row, Row>;
|
||||
using Row_Col_Tuple = ck::Tuple<Row, Col>;
|
||||
|
||||
// Conv layout
|
||||
//
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -41,8 +42,37 @@ void add_device_gemm_add_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add +
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>&);
|
||||
|
||||
void add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -52,17 +82,91 @@ template <typename ALayout,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with Add at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
|
||||
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -80,6 +184,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
@@ -104,10 +209,32 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Add>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -11,11 +11,13 @@
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -67,8 +69,64 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
// GEMM + Add + Add + FastGelu
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>>&);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleDSplitK specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -79,18 +137,100 @@ template <typename ALayout,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddAddFastGelu>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
constexpr bool IsAllDRowLayout = is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row>;
|
||||
constexpr bool IsAllDFloat16 =
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t>;
|
||||
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t> && IsAllDRowLayout && IsAllDFloat16)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + Add + FastGelu
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -100,47 +240,69 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddAddFastGelu>;
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
constexpr bool IsAllDRowLayout = is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row>;
|
||||
constexpr bool IsAllDFloat16 =
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t>;
|
||||
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
is_same_v<EDataType, half_t> && IsAllDRowLayout && IsAllDFloat16)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddAddFastGelu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -150,3 +312,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif // CK_ENABLE_FP16
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -93,8 +94,64 @@ void add_device_gemm_add_fastgelu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
|
||||
void add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>>&);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleDSplitK specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -103,18 +160,97 @@ template <typename ALayout,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_fastgelu_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -132,7 +268,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -143,7 +280,7 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16 && CK_ENABLE_INT8
|
||||
|
||||
#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8)
|
||||
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, int8_t> &&
|
||||
@@ -156,8 +293,9 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_BF16 && CK_ENABLE_INT8
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -186,6 +324,29 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddFastGelu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -70,6 +71,145 @@ void add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>&);
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>&);
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>&);
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>&);
|
||||
|
||||
void add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Multiply
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddMultiply>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::AddMultiply>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_multiply_wmma_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + Multiply
|
||||
template <typename ALayout,
|
||||
@@ -111,6 +251,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
@@ -144,6 +285,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddMultiply>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_relu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -41,6 +42,35 @@ void add_device_gemm_add_relu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
#endif
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
|
||||
void add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>>&);
|
||||
#endif
|
||||
|
||||
// GEMM + Add + Relu
|
||||
template <typename ALayout,
|
||||
@@ -52,17 +82,92 @@ template <typename ALayout,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddRelu at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
|
||||
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_relu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + Relu
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -80,6 +185,7 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
@@ -106,10 +212,32 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddRelu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_add_silu_xdl_c_shuffle_f16_i8_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -41,6 +42,107 @@ void add_device_gemm_add_silu_xdl_c_shuffle_bf16_i8_bf16_bf16_mk_kn_mn_mn_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>>&);
|
||||
|
||||
void add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>>&);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Add + Silu
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// no split-k xdl implementations
|
||||
#endif // CL_USE_XDL
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_silu_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#if defined(CK_ENABLE_BF16)
|
||||
if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
|
||||
is_same_v<D0DataType, ck::bhalf_t> && is_same_v<EDataType, ck::bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_add_silu_wmma_c_shuffle_bf16_bf16_bf16_bf16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_USE_WMMA
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Add + Silu
|
||||
template <typename ALayout,
|
||||
@@ -78,8 +180,11 @@ struct DeviceOperationInstanceFactory<
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
@@ -105,7 +210,28 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CL_USE_XDL
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddSilu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
|
||||
#endif // CK_USE_WMMA
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,7 +16,8 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
@@ -68,8 +69,11 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -121,7 +125,63 @@ void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
#endif // CK_ENABLE_INT8
|
||||
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Bilinear
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -131,18 +191,95 @@ template <typename ALayout,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Bilinear>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddBilinear at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + Bilinear
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -152,14 +289,15 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::Bilinear>;
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#if defined(CK_ENABLE_FP16) && defined(CK_USE_XDL)
|
||||
#if defined(CK_USE_XDL)
|
||||
#if defined(CK_ENABLE_FP16)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -188,8 +326,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_ENABLE_INT8) && defined(CK_USE_WMMA)
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
|
||||
// Bilinear wmma i8 instances are using DeviceGemmMultipleD interface.
|
||||
#if defined(CK_ENABLE_INT8)
|
||||
if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
|
||||
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
|
||||
{
|
||||
@@ -214,7 +375,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_INT8
|
||||
#endif // CK_USE_WMMA
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -67,6 +68,132 @@ void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>&);
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>&);
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>&);
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Row,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>&);
|
||||
|
||||
void add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
|
||||
Col,
|
||||
Empty_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>>&);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Add + FastGelu
|
||||
// DeviceGemmMultipleDSplitK specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with AddFastGelu at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_fastgelu_wmma_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// GEMM + FastGelu
|
||||
template <typename ALayout,
|
||||
@@ -75,17 +202,17 @@ template <typename ALayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -103,6 +230,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
{
|
||||
@@ -127,6 +255,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
Empty_Tuple,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
Empty_Tuple,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
FastGelu>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
@@ -136,4 +286,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#if defined(CK_USE_XDL)
|
||||
void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
@@ -71,9 +72,64 @@ void add_device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_m
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP8
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#if defined(CK_USE_WMMA)
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Row,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
|
||||
void add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Row_Tuple,
|
||||
Row,
|
||||
F16,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>>&);
|
||||
#endif // CK_USE_WMMA_FP8
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
// GEMM + Multiply + Add
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
@@ -84,18 +140,107 @@ template <typename ALayout,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyAdd>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
// No XDL instances for DeviceGemmMultipleDSplitK with MultiplyAdd at the moment
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F16> &&
|
||||
is_same_v<D0DataType, F16> && is_same_v<D1DataType, F16> &&
|
||||
is_same_v<EDataType, F16>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
#ifdef CK_USE_WMMA_FP8
|
||||
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<D0DataType, F32> && is_same_v<D1DataType, F32> &&
|
||||
is_same_v<EDataType, F16>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_add_wmma_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
// DeviceGemmMultipleD specialization
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename D0Layout,
|
||||
typename D1Layout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename D0DataType,
|
||||
typename D1DataType,
|
||||
typename EDataType>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>>
|
||||
{
|
||||
using DeviceOp = DeviceGemmMultipleD<ALayout,
|
||||
BLayout,
|
||||
@@ -105,17 +250,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyAdd>;
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F16> &&
|
||||
is_same_v<D0DataType, F16> && is_same_v<D1DataType, F16> &&
|
||||
is_same_v<EDataType, F16>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
@@ -133,10 +279,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
}
|
||||
}
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<D0DataType, float> && is_same_v<D1DataType, float> &&
|
||||
is_same_v<EDataType, half_t>)
|
||||
#ifdef CK_ENABLE_FP8
|
||||
if constexpr(is_same_v<ADataType, F16> && is_same_v<BDataType, F8> &&
|
||||
is_same_v<D0DataType, F32> && is_same_v<D1DataType, F32> &&
|
||||
is_same_v<EDataType, F16>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<D0Layout, Row> && is_same_v<D1Layout, Row> &&
|
||||
@@ -153,7 +299,29 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP8
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
// Reuse DeviceGemmMultipleDSplitK instances
|
||||
using Wrapper = DeviceGemmMultipleDSplitKWrapper<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<D0Layout, D1Layout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<D0DataType, D1DataType>,
|
||||
EDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyAdd>;
|
||||
auto new_op_ptrs =
|
||||
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
|
||||
for(auto& op_ptr : new_op_ptrs)
|
||||
{
|
||||
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP8
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
|
||||
@@ -199,7 +200,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_i
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
#endif
|
||||
#endif // CK_ENABLE_BF16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
@@ -278,8 +279,8 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_ENABLE_FP8
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
|
||||
@@ -463,7 +464,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16
|
||||
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
|
||||
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
|
||||
@@ -544,7 +545,62 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_in
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16 || CK_ENABLE_INT8
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
F16_F16_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
F32_F32_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
|
||||
void add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
|
||||
Col,
|
||||
Row_Col_Tuple,
|
||||
Row,
|
||||
F8,
|
||||
F8,
|
||||
F32_F32_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>>& instances);
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -553,36 +609,35 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>>
|
||||
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::MultiplyMultiply>;
|
||||
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
|
||||
BLayout,
|
||||
Tuple<Row, Col>,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
MultiplyMultiply>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_FP8
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
@@ -624,7 +679,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_ENABLE_BF16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
@@ -665,8 +720,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
#endif // CK_ENABLE_FP16
|
||||
#endif // CK_ENABLE_FP8
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
|
||||
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
@@ -691,6 +746,51 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_f16_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
|
||||
is_same_v<CDataType, bhalf_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<CLayout, Row>)
|
||||
{
|
||||
add_device_gemm_multiply_multiply_wmma_c_shuffle_f8_f8_bf16_km_nk_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_WMMA
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user