From 876e9540e433572c9b138ca0de76b5e30b988430 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:58:23 -0700 Subject: [PATCH] Revert "Add support for mixed precision in contraction scale and bilinear" (#967) * Revert "Add support for mixed precision in contraction scale and bilinear (#936)" This reverts commit ba712aee9a3823cbc44383780761783af0934160. * revert commits #957 and #960 [ROCm/composable_kernel commit: 4daedf8ca56f3bd93481708bd9d762045839ec20] --- example/26_contraction/CMakeLists.txt | 48 - example/26_contraction/common_instances.hpp | 183 ---- ...raction_bilinear_xdl_bf16_compute_fp32.cpp | 86 -- ...raction_bilinear_xdl_fp16_compute_fp32.cpp | 86 -- .../contraction_bilinear_xdl_fp32.cpp | 319 +++++-- ...raction_bilinear_xdl_fp32_compute_bf16.cpp | 86 -- ...raction_bilinear_xdl_fp32_compute_fp16.cpp | 86 -- .../contraction_bilinear_xdl_fp64.cpp | 319 +++++-- ...raction_bilinear_xdl_fp64_compute_fp32.cpp | 86 -- ...ontraction_scale_xdl_bf16_compute_fp32.cpp | 85 -- ...ontraction_scale_xdl_fp16_compute_fp32.cpp | 85 -- .../contraction_scale_xdl_fp32.cpp | 303 +++++-- ...ontraction_scale_xdl_fp32_compute_bf16.cpp | 85 -- ...ontraction_scale_xdl_fp32_compute_fp16.cpp | 85 -- .../contraction_scale_xdl_fp64.cpp | 303 +++++-- ...ontraction_scale_xdl_fp64_compute_fp32.cpp | 85 -- .../run_contraction_bilinear_example.inc | 234 ----- .../run_contraction_scale_example.inc | 217 ----- example/60_gemm_multi_ABD/CMakeLists.txt | 10 - .../gemm_multi_ABD_xdl_fp16.cpp | 362 -------- .../61_contraction_multi_ABD/CMakeLists.txt | 10 - .../contraction_multi_ABD_xdl_fp16.cpp | 329 ------- .../device_contraction_multiple_abd.hpp | 61 -- .../device/device_contraction_multiple_d.hpp | 3 +- ..._contraction_multiple_abd_xdl_cshuffle.hpp | 846 ------------------ ...ce_contraction_multiple_d_xdl_cshuffle.hpp | 6 +- .../element/binary_element_wise_operation.hpp | 19 - .../element/unary_element_wise_operation.hpp | 26 - ...ridwise_gemm_multiple_abd_xdl_cshuffle.hpp | 3 +- .../cpu/reference_contraction.hpp | 16 +- .../device_operation_instance_factory.hpp | 2 - .../device_contraction_instance.hpp | 292 ------ .../gpu/contraction_bilinear.hpp | 424 +-------- .../gpu/contraction_scale.hpp | 420 +-------- .../gpu/contraction_bilinear/CMakeLists.txt | 44 +- ...16_bf16_bf16_compute_f32_kknn_instance.cpp | 57 -- ...16_bf16_bf16_compute_f32_knnn_instance.cpp | 57 -- ...16_bf16_bf16_compute_f32_mknn_instance.cpp | 57 -- ...16_bf16_bf16_compute_f32_mnnn_instance.cpp | 57 -- ..._f16_f16_f16_compute_f32_kknn_instance.cpp | 57 -- ..._f16_f16_f16_compute_f32_knnn_instance.cpp | 57 -- ..._f16_f16_f16_compute_f32_mknn_instance.cpp | 57 -- ..._f16_f16_f16_compute_f32_mnnn_instance.cpp | 57 -- ...f32_f32_f32_compute_bf16_kknn_instance.cpp | 57 -- ...f32_f32_f32_compute_bf16_knnn_instance.cpp | 57 -- ...f32_f32_f32_compute_bf16_mknn_instance.cpp | 57 -- ...f32_f32_f32_compute_bf16_mnnn_instance.cpp | 57 -- ..._f32_f32_f32_compute_f16_kknn_instance.cpp | 57 -- ..._f32_f32_f32_compute_f16_knnn_instance.cpp | 57 -- ..._f32_f32_f32_compute_f16_mknn_instance.cpp | 57 -- ..._f32_f32_f32_compute_f16_mnnn_instance.cpp | 57 -- ..._shuffle_f32_f32_f32_f32_kknn_instance.cpp | 52 +- ..._shuffle_f32_f32_f32_f32_knnn_instance.cpp | 55 +- ..._shuffle_f32_f32_f32_f32_mknn_instance.cpp | 55 +- ..._shuffle_f32_f32_f32_f32_mnnn_instance.cpp | 55 +- ..._f64_f64_f64_compute_f32_kknn_instance.cpp | 57 -- ..._f64_f64_f64_compute_f32_knnn_instance.cpp | 57 -- ..._f64_f64_f64_compute_f32_mknn_instance.cpp | 57 -- ..._f64_f64_f64_compute_f32_mnnn_instance.cpp | 57 -- ..._shuffle_f64_f64_f64_f64_kknn_instance.cpp | 49 +- ..._shuffle_f64_f64_f64_f64_knnn_instance.cpp | 49 +- ..._shuffle_f64_f64_f64_f64_mknn_instance.cpp | 49 +- ..._shuffle_f64_f64_f64_f64_mnnn_instance.cpp | 49 +- .../gpu/contraction_scale/CMakeLists.txt | 32 +- ...f16_bf16_bf16_compute_f32_kkn_instance.cpp | 57 -- ...f16_bf16_bf16_compute_f32_knn_instance.cpp | 57 -- ...f16_bf16_bf16_compute_f32_mkn_instance.cpp | 57 -- ...f16_bf16_bf16_compute_f32_mnn_instance.cpp | 57 -- ...e_f16_f16_f16_compute_f32_kkn_instance.cpp | 57 -- ...e_f16_f16_f16_compute_f32_knn_instance.cpp | 57 -- ...e_f16_f16_f16_compute_f32_mkn_instance.cpp | 57 -- ...e_f16_f16_f16_compute_f32_mnn_instance.cpp | 57 -- ..._f32_f32_f32_compute_bf16_kkn_instance.cpp | 57 -- ..._f32_f32_f32_compute_bf16_knn_instance.cpp | 57 -- ..._f32_f32_f32_compute_bf16_mkn_instance.cpp | 57 -- ..._f32_f32_f32_compute_bf16_mnn_instance.cpp | 57 -- ...e_f32_f32_f32_compute_f16_kkn_instance.cpp | 57 -- ...e_f32_f32_f32_compute_f16_knn_instance.cpp | 57 -- ...e_f32_f32_f32_compute_f16_mkn_instance.cpp | 57 -- ...e_f32_f32_f32_compute_f16_mnn_instance.cpp | 57 -- ...xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp | 56 +- ...xdl_c_shuffle_f32_f32_f32_knn_instance.cpp | 59 +- ...xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp | 59 +- ...xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp | 59 +- ...e_f64_f64_f64_compute_f32_kkn_instance.cpp | 57 -- ...e_f64_f64_f64_compute_f32_knn_instance.cpp | 57 -- ...e_f64_f64_f64_compute_f32_mkn_instance.cpp | 57 -- ...e_f64_f64_f64_compute_f32_mnn_instance.cpp | 57 -- ...xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp | 49 +- ...xdl_c_shuffle_f64_f64_f64_knn_instance.cpp | 49 +- ...xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp | 49 +- ...xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp | 49 +- profiler/README.md | 24 +- .../profiler/profile_contraction_impl.hpp | 62 +- .../profiler/profile_contraction_utils.hpp | 14 +- profiler/src/profile_contraction_bilinear.cpp | 219 ++--- profiler/src/profile_contraction_scale.cpp | 215 ++--- test/contraction/test_contraction.cpp | 145 ++- .../test_contraction_interface.cpp | 10 +- 99 files changed, 1961 insertions(+), 7536 deletions(-) delete mode 100644 example/26_contraction/common_instances.hpp delete mode 100644 example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp delete mode 100644 example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp delete mode 100644 example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp delete mode 100644 example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp delete mode 100644 example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp delete mode 100644 example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp delete mode 100644 example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp delete mode 100644 example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp delete mode 100644 example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp delete mode 100644 example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp delete mode 100644 example/26_contraction/run_contraction_bilinear_example.inc delete mode 100644 example/26_contraction/run_contraction_scale_example.inc delete mode 100644 example/60_gemm_multi_ABD/CMakeLists.txt delete mode 100644 example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp delete mode 100644 example/61_contraction_multi_ABD/CMakeLists.txt delete mode 100644 example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp delete mode 100644 include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp delete mode 100644 include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp delete mode 100644 library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp diff --git a/example/26_contraction/CMakeLists.txt b/example/26_contraction/CMakeLists.txt index 1a0489ce9c..167a9f147d 100644 --- a/example/26_contraction/CMakeLists.txt +++ b/example/26_contraction/CMakeLists.txt @@ -1,52 +1,4 @@ -add_custom_target(example_contraction) -add_custom_target(example_contraction_scale) -add_custom_target(example_contraction_bilinear) - -# FP32 add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32) - add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32) - -add_example_executable(example_contraction_bilinear_xdl_fp32_compute_bf16 contraction_bilinear_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_bf16) - -add_example_executable(example_contraction_scale_xdl_fp32_compute_bf16 contraction_scale_xdl_fp32_compute_bf16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_bf16) - -add_example_executable(example_contraction_bilinear_xdl_fp32_compute_fp16 contraction_bilinear_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp32_compute_fp16) - -add_example_executable(example_contraction_scale_xdl_fp32_compute_fp16 contraction_scale_xdl_fp32_compute_fp16.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp32_compute_fp16) - -# FP64 add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64) - add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64) - -add_example_executable(example_contraction_bilinear_xdl_fp64_compute_fp32 contraction_bilinear_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp64_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_fp64_compute_fp32 contraction_scale_xdl_fp64_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp64_compute_fp32) - -# FP16 -add_example_executable(example_contraction_bilinear_xdl_fp16_compute_fp32 contraction_bilinear_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_fp16_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_fp16_compute_fp32 contraction_scale_xdl_fp16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_fp16_compute_fp32) - -# BF16 -add_example_executable(example_contraction_bilinear_xdl_bf16_compute_fp32 contraction_bilinear_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_bilinear example_contraction_bilinear_xdl_bf16_compute_fp32) - -add_example_executable(example_contraction_scale_xdl_bf16_compute_fp32 contraction_scale_xdl_bf16_compute_fp32.cpp) -add_dependencies(example_contraction_scale example_contraction_scale_xdl_bf16_compute_fp32) - -add_dependencies(example_contraction example_contraction_scale) -add_dependencies(example_contraction example_contraction_bilinear) diff --git a/example/26_contraction/common_instances.hpp b/example/26_contraction/common_instances.hpp deleted file mode 100644 index eb2fa5ce73..0000000000 --- a/example/26_contraction/common_instances.hpp +++ /dev/null @@ -1,183 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" - -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; - -template -using S = ck::Sequence; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Generic instances for fp32, fp16 and bf16 data types. -// clang-format off -template -using DeviceOpInstanceKK_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceKN_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceMK_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; - -template -using DeviceOpInstanceMN_Generic = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; - -// Fp64 instances. -template -using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; - -template -using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; - -// clang-format on diff --git a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp deleted file mode 100644 index 77ec5669a7..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_bf16_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = BF16; -using DDataType = BF16; -using DsDataType = ck::Tuple; -using EDataType = BF16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp deleted file mode 100644 index 25c4d9eac4..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp16_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F16; -using DDataType = F16; -using DsDataType = ck::Tuple; -using EDataType = F16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp index d440e7394e..78522160c8 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp32.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp32.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F32; using BDataType = F32; @@ -13,7 +32,6 @@ using CShuffleDataType = F32; using DDataType = F32; using DsDataType = ck::Tuple; using EDataType = F32; -using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -23,64 +41,253 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; -#include "run_contraction_bilinear_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp deleted file mode 100644 index 5b748949bd..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_bf16.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F32; -using DsDataType = ck::Tuple; -using EDataType = F32; -using ComputeDataType = BF16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp deleted file mode 100644 index 334cd615d1..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp32_compute_fp16.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F32; -using DsDataType = ck::Tuple; -using EDataType = F32; -using ComputeDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp index 86ea7e0212..6cceed5bc1 100644 --- a/example/26_contraction/contraction_bilinear_xdl_fp64.cpp +++ b/example/26_contraction/contraction_bilinear_xdl_fp64.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F64; using BDataType = F64; @@ -13,7 +32,6 @@ using CShuffleDataType = F64; using DDataType = F64; using DsDataType = ck::Tuple; using EDataType = F64; -using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -23,64 +41,253 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; +// clang-format off +using DeviceOpInstanceKKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; +using DeviceOpInstanceKNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; +using DeviceOpInstanceMKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKNN; -#include "run_contraction_bilinear_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // D[M0, M1, N0, N1] + std::vector d_ms_ns_lengths{30, 128, 32, 64}; + std::vector d_ms_ns_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float alpha = 1.f; + float beta = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 28) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + d_ms_ns_lengths = {M0, M1, N0, N1}; + d_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; + + alpha = std::stof(argv[26]); + beta = std::stof(argv[27]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); + printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg26 to 27: alpha, beta\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + d_device_buf.ToDevice(d_ms_ns.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{alpha, beta}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{d_device_buf.GetDeviceBuffer()}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 1>{d_ms_ns_lengths}, + std::array, 1>{d_ms_ns_strides}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1), + d_ms_ns(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp deleted file mode 100644 index b0a2a9a10f..0000000000 --- a/example/26_contraction/contraction_bilinear_xdl_fp64_compute_fp32.cpp +++ /dev/null @@ -1,86 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F64; -using BDataType = F64; -using AccDataType = F32; -using CShuffleDataType = F64; -using DDataType = F64; -using DsDataType = ck::Tuple; -using EDataType = F64; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; - -using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKNN; - -#include "run_contraction_bilinear_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_bilinear_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp deleted file mode 100644 index 32652fc6f3..0000000000 --- a/example/26_contraction/contraction_scale_xdl_bf16_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = BF16; -using BDataType = BF16; -using AccDataType = F32; -using CShuffleDataType = BF16; -using DsDataType = ck::Tuple<>; -using EDataType = BF16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp deleted file mode 100644 index a4797e2400..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp16_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F16; -using DsDataType = ck::Tuple<>; -using EDataType = F16; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp32.cpp index fe984b29d6..1574f5d18f 100644 --- a/example/26_contraction/contraction_scale_xdl_fp32.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp32.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F32; using BDataType = F32; @@ -12,7 +31,6 @@ using AccDataType = F32; using CShuffleDataType = F32; using DsDataType = ck::Tuple<>; using EDataType = F32; -using ComputeDataType = F32; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -22,64 +40,237 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKN; -#include "run_contraction_scale_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp deleted file mode 100644 index 83c8fec179..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_bf16.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = F32; -using ComputeDataType = BF16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp b/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp deleted file mode 100644 index 27e42e8203..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp32_compute_fp16.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F32; -using BDataType = F32; -using AccDataType = F32; -using CShuffleDataType = F32; -using DsDataType = ck::Tuple<>; -using EDataType = F32; -using ComputeDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_Generic; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_Generic; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_Generic; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_Generic; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/contraction_scale_xdl_fp64.cpp b/example/26_contraction/contraction_scale_xdl_fp64.cpp index 37374780b6..3dacc70887 100644 --- a/example/26_contraction/contraction_scale_xdl_fp64.cpp +++ b/example/26_contraction/contraction_scale_xdl_fp64.cpp @@ -1,10 +1,29 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +#include +#include + #include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "common_instances.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/numeric.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" + +template +using S = ck::Sequence; + +using F64 = double; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = F64; using BDataType = F64; @@ -12,7 +31,6 @@ using AccDataType = F64; using CShuffleDataType = F64; using DsDataType = ck::Tuple<>; using EDataType = F64; -using ComputeDataType = F64; static constexpr ck::index_t NumDimM = 2; static constexpr ck::index_t NumDimN = 2; @@ -22,64 +40,237 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Scale; -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; +// clang-format off +using DeviceOpInstanceKKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; +using DeviceOpInstanceKNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; +using DeviceOpInstanceMKN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1>; + +using DeviceOpInstanceMNN = ck::tensor_operation::device:: + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F64, F64, F64, F64, DsDataType, F64, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>; +// clang-format on using DeviceOpInstance = DeviceOpInstanceKKN; -#include "run_contraction_scale_example.inc" +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } + // A[M0, M1, K0, K1] + std::vector a_ms_ks_lengths{30, 128, 32, 64}; + std::vector a_ms_ks_strides{524288, 4096, 128, 1}; + // B[N0, N1, K0, K1] + std::vector b_ns_ks_lengths{32, 64, 32, 64}; + std::vector b_ns_ks_strides{524288, 4096, 128, 1}; + // E[M0, M1, N0, N1] + std::vector e_ms_ns_lengths{30, 128, 32, 64}; + std::vector e_ms_ns_strides{524288, 4096, 128, 1}; + + float scale = 1.f; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 23) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + const ck::index_t M0 = std::stoi(argv[4]); + const ck::index_t M1 = std::stoi(argv[5]); + + const ck::index_t N0 = std::stoi(argv[6]); + const ck::index_t N1 = std::stoi(argv[7]); + + const ck::index_t K0 = std::stoi(argv[8]); + const ck::index_t K1 = std::stoi(argv[9]); + + a_ms_ks_lengths = {M0, M1, K0, K1}; + a_ms_ks_strides = { + std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; + + b_ns_ks_lengths = {N0, N1, K0, K1}; + b_ns_ks_strides = { + std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; + + e_ms_ns_lengths = {M0, M1, N0, N1}; + e_ms_ns_strides = { + std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; + + scale = std::stof(argv[22]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); + printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); + printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); + printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); + printf("arg22: scale\n"); + exit(0); + } + + Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); + Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); + Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); + + std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; + std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; + std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_ms_ks.mData.data()); + b_device_buf.ToDevice(b_ns_ks.mData.data()); + + // set zero + e_device_buf.SetZero(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{scale}; + + // device operation + auto op = DeviceOpInstance{}; + auto invoker = op.MakeInvoker(); + auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + std::array, 0>{}, + std::array, 0>{}, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + + if(!op.IsSupportedArgument(argument)) + { + std::cout << op.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + ck::index_t M = + ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); + + ck::index_t N = ck::accumulate_n( + e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); + + ck::index_t K = ck::accumulate_n( + a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); + + if(do_verification) + { + Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); + + using ReferenceOpInstance = + ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; + + auto ref_op = ReferenceOpInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + + Tensor empty_tensor(std::vector{}, std::vector{}); + auto ref_argument = + ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); + + ref_invoker.Run(ref_argument); + + for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) + { + for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) + { + for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) + { + for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) + { + cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), + c_ms_ns_host_result(m0, m1, n0, n1)); + } + } + } + } + + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } + + return 0; +} diff --git a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp b/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp deleted file mode 100644 index 5dbeb23154..0000000000 --- a/example/26_contraction/contraction_scale_xdl_fp64_compute_fp32.cpp +++ /dev/null @@ -1,85 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "common_instances.hpp" - -using ADataType = F64; -using BDataType = F64; -using AccDataType = F32; -using CShuffleDataType = F64; -using DsDataType = ck::Tuple<>; -using EDataType = F64; -using ComputeDataType = F32; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::Scale; - -using DeviceOpInstanceKKN = DeviceOpInstanceKK_FP64; - -using DeviceOpInstanceKNN = DeviceOpInstanceKN_FP64; - -using DeviceOpInstanceMKN = DeviceOpInstanceMK_FP64; - -using DeviceOpInstanceMNN = DeviceOpInstanceMN_FP64; - -using DeviceOpInstance = DeviceOpInstanceKKN; - -#include "run_contraction_scale_example.inc" - -int main(int argc, char* argv[]) { return run_contraction_scale_example(argc, argv); } diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc deleted file mode 100644 index 7b9a01e0ae..0000000000 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ /dev/null @@ -1,234 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -int run_contraction_bilinear_example(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float alpha = 1.f; - float beta = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 28) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - d_ms_ns_lengths = {M0, M1, N0, N1}; - d_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])}; - - alpha = std::stof(argv[26]); - beta = std::stof(argv[27]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n"); - printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg26 to 27: alpha, beta\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + - sizeof(DDataType) * M * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc deleted file mode 100644 index 65ca182da2..0000000000 --- a/example/26_contraction/run_contraction_scale_example.inc +++ /dev/null @@ -1,217 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/numeric.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" - -int run_contraction_scale_example(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // A[M0, M1, K0, K1] - std::vector a_ms_ks_lengths{30, 128, 32, 64}; - std::vector a_ms_ks_strides{524288, 4096, 128, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{524288, 4096, 128, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{524288, 4096, 128, 1}; - - float scale = 1.f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 23) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - const ck::index_t M0 = std::stoi(argv[4]); - const ck::index_t M1 = std::stoi(argv[5]); - - const ck::index_t N0 = std::stoi(argv[6]); - const ck::index_t N1 = std::stoi(argv[7]); - - const ck::index_t K0 = std::stoi(argv[8]); - const ck::index_t K1 = std::stoi(argv[9]); - - a_ms_ks_lengths = {M0, M1, K0, K1}; - a_ms_ks_strides = { - std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])}; - - b_ns_ks_lengths = {N0, N1, K0, K1}; - b_ns_ks_strides = { - std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])}; - - e_ms_ns_lengths = {M0, M1, N0, N1}; - e_ms_ns_strides = { - std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])}; - - scale = std::stof(argv[22]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n"); - printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n"); - printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n"); - printf("arg18 to 21: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n"); - printf("arg22: scale\n"); - exit(0); - } - - Tensor a_ms_ks(a_ms_ks_lengths, a_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a_ms_ks: " << a_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a_device_buf(sizeof(ADataType) * a_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{scale}; - - // device operation - auto op = DeviceOpInstance{}; - auto invoker = op.MakeInvoker(); - auto argument = op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - e_device_buf.GetDeviceBuffer(), - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - std::array, 0>{}, - std::array, 0>{}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!op.IsSupportedArgument(argument)) - { - std::cout << op.GetTypeString() << " does not support this problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << op.GetTypeString() << std::endl; - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - if(do_verification) - { - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, a_element_op, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1)); - } - } - } - } - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/60_gemm_multi_ABD/CMakeLists.txt b/example/60_gemm_multi_ABD/CMakeLists.txt deleted file mode 100644 index 0c6ce7ae87..0000000000 --- a/example/60_gemm_multi_ABD/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() diff --git a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp b/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp deleted file mode 100644 index 90558364a4..0000000000 --- a/example/60_gemm_multi_ABD/gemm_multi_ABD_xdl_fp16.cpp +++ /dev/null @@ -1,362 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" -#include "ck/library/utility/check_err.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ADataType = F16; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -using ALayout = Row; -using BLayout = Col; -using DLayout = Row; -using ELayout = Row; - -struct AddScale -{ - static constexpr auto I0 = ck::Number<0>{}; - static constexpr auto I1 = ck::Number<1>{}; - static constexpr auto I2 = ck::Number<2>{}; - static constexpr auto I3 = ck::Number<3>{}; - - __host__ __device__ constexpr void - operator()(ck::half4_t& a, const ck::half4_t& a0, const ck::half4_t& a1) const - { - const auto a0_v_t = ck::vector_type{a0}; - const auto a1_v_t = ck::vector_type{a1}; - - auto r_v_t = ck::vector_type{}; - - r_v_t.AsType()(I0) = - scale * (a0_v_t.AsType()[I0] + a1_v_t.AsType()[I0]); - r_v_t.AsType()(I1) = - scale * (a0_v_t.AsType()[I1] + a1_v_t.AsType()[I1]); - r_v_t.AsType()(I2) = - scale * (a0_v_t.AsType()[I2] + a1_v_t.AsType()[I2]); - r_v_t.AsType()(I3) = - scale * (a0_v_t.AsType()[I3] + a1_v_t.AsType()[I3]); - - a = r_v_t.AsType()[I0]; - } - - __host__ __device__ constexpr void - operator()(ck::half_t& a, const ck::half_t& a0, const ck::half_t& a1) const - { - a = scale * (a0 + a1); - } - - // this attribute will force copy_function applying element_wise with vector_type - static constexpr ck::index_t vec_len = 4; - - float scale = 1.0; -}; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; - - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - - template <> - __host__ __device__ constexpr void operator()( - ck::half_t& e, const float& c, const ck::half_t& d) const - { - e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); - }; - - float alpha_; - float beta_; -}; - -using AElementOp = AddScale; -using BElementOp = PassThrough; -using CDEElementOp = AlphaBetaAdd; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleABD_Xdl_CShuffle< - ck::Tuple, - ck::Tuple, - ck::Tuple, - ELayout, - ck::Tuple, - ck::Tuple, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 1, - 256, - 256, - 128, - 32, - 8, - 8, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - S<1, 32, 1, 8>, - 8>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideD = 4096; - ck::index_t StrideE = 4096; - - float alpha = 1.0f; - float beta = 1.0f; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 6) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - alpha = std::stof(argv[4]); - beta = std::stof(argv[5]); - } - else if(argc == 13) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideD = std::stoi(argv[9]); - StrideE = std::stoi(argv[10]); - - alpha = std::stof(argv[11]); - beta = std::stof(argv[12]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, " - "beta\n"); - exit(0); - } - - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - using namespace ck::literals; - - if(std::is_same::value) - { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor a1_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{})); - Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); - - std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; - std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; - std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a0_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem a0_device_buf(sizeof(ADataType) * a0_m_k.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(ADataType) * a1_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_m_k.mData.data()); - a1_device_buf.ToDevice(a1_m_k.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); - d_device_buf.ToDevice(d_m_n.mData.data()); - e_device_buf.ToDevice(e_m_n_device_result.mData.data()); - - auto a_element_op = AElementOp{0.2}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, - std::array{b_device_buf.GetDeviceBuffer()}, - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - M, - N, - K, - std::array{StrideA, StrideA}, - std::array{StrideB}, - std::array{StrideD}, - StrideE, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - if(do_verification) - { - Tensor c_m_n({M, N}); - - Tensor a_m_k({M, K}); - - for(int m = 0; m < M; ++m) - { - for(int k = 0; k < K; ++k) - { - a_element_op(a_m_k(m, k), a0_m_k(m, k), a1_m_k(m, k)); - } - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = - ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, b_element_op, PassThrough{}); - - ref_invoker.Run(ref_argument); - - for(int m = 0; m < M; ++m) - { - for(int n = 0; n < N; ++n) - { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n)); - } - } - - e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - - return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/example/61_contraction_multi_ABD/CMakeLists.txt b/example/61_contraction_multi_ABD/CMakeLists.txt deleted file mode 100644 index 0420fdb9c3..0000000000 --- a/example/61_contraction_multi_ABD/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) -list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list2 AND target EQUAL 0) - add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp) - set(target 1) - endif() -endforeach() -endif() diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp deleted file mode 100644 index 5a274d485e..0000000000 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ /dev/null @@ -1,329 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/numeric.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using A0DataType = F16; -using A1DataType = F32; -using BDataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using DDataType = F16; -using EDataType = F16; - -static constexpr ck::index_t NumDimM = 2; -static constexpr ck::index_t NumDimN = 2; -static constexpr ck::index_t NumDimK = 2; - -struct AlphaBetaAdd -{ - AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; - - template - __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; - - template <> - __host__ __device__ constexpr void operator()( - ck::half_t& e, const float& c, const ck::half_t& d) const - { - e = ck::type_convert(alpha_ * c + beta_ * ck::type_convert(d)); - }; - - float alpha_; - float beta_; -}; - -struct Multiply -{ - __host__ __device__ constexpr void - operator()(ck::half_t& a, const ck::half_t& a0, const float& a1) const - { - a = ck::type_convert(ck::type_convert(a0) * a1); - } -}; - -using AElementOp = Multiply; -using BElementOp = PassThrough; -using CDEElementOp = AlphaBetaAdd; - -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using DeviceOpInstance = ck::tensor_operation::device::DeviceContractionMultipleABD_Xdl_CShuffle< - NumDimM, - NumDimN, - NumDimK, - ck::Tuple, - ck::Tuple, - AccDataType, - CShuffleDataType, - ck::Tuple, - EDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 1, - 256, - 256, - 128, - 32, - 8, - 8, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - 1, - 1, - 1, - S<1, 32, 1, 8>, - 8>; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - float alpha = 1.0f; - float beta = 1.0f; - - // A0[M0, M1, K0, K1] - std::vector a0_ms_ks_lengths{30, 128, 32, 64}; - std::vector a0_ms_ks_strides{128 * 32 * 64, 32 * 64, 64, 1}; - // A1[M1, K1] -> A1[M0, M1, K0, K1] - std::vector a1_ms_ks_lengths{30, 128, 32, 64}; - std::vector a1_ms_ks_strides{0, 64, 0, 1}; - // B[N0, N1, K0, K1] - std::vector b_ns_ks_lengths{32, 64, 32, 64}; - std::vector b_ns_ks_strides{64 * 32 * 64, 32 * 64, 64, 1}; - // D[M0, M1, N0, N1] - std::vector d_ms_ns_lengths{30, 128, 32, 64}; - std::vector d_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; - // E[M0, M1, N0, N1] - std::vector e_ms_ns_lengths{30, 128, 32, 64}; - std::vector e_ms_ns_strides{128 * 32 * 64, 32 * 64, 64, 1}; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } - - Tensor a0_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - Tensor a1_ms_ks(a1_ms_ks_lengths, a1_ms_ks_strides); - Tensor b_ns_ks(b_ns_ks_lengths, b_ns_ks_strides); - Tensor d_ms_ns(d_ms_ns_lengths, d_ms_ns_strides); - Tensor e_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - Tensor e_ms_ns_device_result(e_ms_ns_lengths, e_ms_ns_strides); - - std::cout << "a0_ms_ks: " << a0_ms_ks.mDesc << std::endl; - std::cout << "a1_ms_ks: " << a1_ms_ks.mDesc << std::endl; - std::cout << "b_ns_ks: " << b_ns_ks.mDesc << std::endl; - std::cout << "d_ms_ns: " << d_ms_ns.mDesc << std::endl; - std::cout << "e_ms_ns: " << e_ms_ns_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a0_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - a1_ms_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - a0_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - a1_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_ns_ks.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - d_ms_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - } - - DeviceMem a0_device_buf(sizeof(A0DataType) * a0_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem a1_device_buf(sizeof(A1DataType) * a1_ms_ks.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_ns_ks.mDesc.GetElementSpaceSize()); - DeviceMem d_device_buf(sizeof(DDataType) * d_ms_ns.mDesc.GetElementSpaceSize()); - DeviceMem e_device_buf(sizeof(EDataType) * e_ms_ns_device_result.mDesc.GetElementSpaceSize()); - - a0_device_buf.ToDevice(a0_ms_ks.mData.data()); - a1_device_buf.ToDevice(a1_ms_ks.mData.data()); - b_device_buf.ToDevice(b_ns_ks.mData.data()); - d_device_buf.ToDevice(d_ms_ns.mData.data()); - - // set zero - e_device_buf.SetZero(); - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{alpha, beta}; - - // do GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument( - std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, - std::array{b_device_buf.GetDeviceBuffer()}, - std::array{d_device_buf.GetDeviceBuffer()}, - e_device_buf.GetDeviceBuffer(), - std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, - std::array, 2>{a0_ms_ks_strides, a1_ms_ks_strides}, - std::array, 1>{b_ns_ks_lengths}, - std::array, 1>{b_ns_ks_strides}, - std::array, 1>{d_ms_ns_lengths}, - std::array, 1>{d_ms_ns_strides}, - e_ms_ns_lengths, - e_ms_ns_strides, - a_element_op, - b_element_op, - cde_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_contraction with the specified compilation parameters does " - "not support this problem"); - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - if(time_kernel) - { - ck::index_t M = - ck::accumulate_n(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{}); - - ck::index_t N = ck::accumulate_n( - e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{}); - - ck::index_t K = ck::accumulate_n( - a0_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(A0DataType) * M * K + sizeof(BDataType) * K * N + +sizeof(EDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; - } - - if(do_verification) - { - - Tensor c_ms_ns_host_result(e_ms_ns_lengths, e_ms_ns_strides); - - Tensor a_ms_ks(a0_ms_ks_lengths, a0_ms_ks_strides); - - for(size_t m0 = 0; m0 < a_ms_ks.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < a_ms_ks.mDesc.GetLengths()[1]; ++m1) - { - for(size_t k0 = 0; k0 < a_ms_ks.mDesc.GetLengths()[2]; ++k0) - { - for(size_t k1 = 0; k1 < a_ms_ks.mDesc.GetLengths()[3]; ++k1) - { - a_element_op(a_ms_ks(m0, m1, k0, k1), - a0_ms_ks(m0, m1, k0, k1), - a1_ms_ks(m0, m1, k0, k1)); - } - } - } - } - - using ReferenceOpInstance = - ck::tensor_operation::host::ReferenceContraction_M2_N2_K2; - - auto ref_op = ReferenceOpInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - - Tensor empty_tensor(std::vector{}, std::vector{}); - auto ref_argument = - ref_op.MakeArgument(a_ms_ks, b_ns_ks, c_ms_ns_host_result, PassThrough{}, b_element_op); - - ref_invoker.Run(ref_argument); - - for(size_t m0 = 0; m0 < e_ms_ns_host_result.mDesc.GetLengths()[0]; ++m0) - { - for(size_t m1 = 0; m1 < e_ms_ns_host_result.mDesc.GetLengths()[1]; ++m1) - { - for(size_t n0 = 0; n0 < e_ms_ns_host_result.mDesc.GetLengths()[2]; ++n0) - { - for(size_t n1 = 0; n1 < e_ms_ns_host_result.mDesc.GetLengths()[3]; ++n1) - { - cde_element_op(e_ms_ns_host_result(m0, m1, n0, n1), - c_ms_ns_host_result(m0, m1, n0, n1), - d_ms_ns(m0, m1, n0, n1)); - } - } - } - } - - e_device_buf.FromDevice(e_ms_ns_device_result.mData.data()); - - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; - } - - return 0; -} diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp deleted file mode 100644 index 3116d40474..0000000000 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp +++ /dev/null @@ -1,61 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/tensor_operation/gpu/device/device_base.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// GEMM: -// input : A0[M0, M1, ... K0, K1, ...], ... -// input : B0[N0, N1, ... K0, K1, ...], ... -// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... -// output : E[M0, M1, ... N0, N1, ...] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -// Assume: -// D0, D1, ... and E have the same layout -template -struct DeviceContractionMultipleABD : public BaseOperator -{ - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - virtual std::unique_ptr - MakeArgumentPointer(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) = 0; - - virtual std::unique_ptr MakeInvokerPointer() = 0; -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp index dffb51d0ba..118ade8978 100644 --- a/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp +++ b/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -33,8 +33,7 @@ template + typename CDEElementwiseOperation> struct DeviceContractionMultipleD : public BaseOperator { static constexpr index_t NumDTensor = DsDataType::Size(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp deleted file mode 100644 index a4c6767434..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ /dev/null @@ -1,846 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_contraction_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, - const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(p_as_grid, - p_bs_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - as_grid_desc_ak0_m_ak1, - bs_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); -#else - ignore = p_as_grid; - ignore = p_bs_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = as_grid_desc_ak0_m_ak1; - ignore = bs_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; -#endif -} - -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace device { - -// GEMM: -// input : A[M, K] -// input : B[N, K] -// input : D0[M, N], D1[M, N], ... -// output : E[M, N] -// C = a_op(A) * b_op(B) -// E = cde_op(C, D0, D1, ...) -// Assume: -// D0, D1, ... and E have the same layout -template -struct DeviceContractionMultipleABD_Xdl_CShuffle - : public DeviceContractionMultipleABD -{ - using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; - - static constexpr index_t NumATensor = AsDataType::Size(); - static constexpr index_t NumBTensor = BsDataType::Size(); - static constexpr index_t NumDTensor = DsDataType::Size(); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - using ComputeDataType = EDataType; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< - AsDataType, - BsDataType, - ComputeDataType, - AccDataType, - CShuffleDataType, - DsDataType, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, - NumGemmKPrefetchStage, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEBlockTransferScalarPerVector_NPerBlock, - LoopSched, - PipelineVer>; - - static constexpr auto matrix_padder = - ck::tensor_operation::device::MatrixPadder{ - MPerBlock, NPerBlock, KPerBlock}; - - static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, - const std::vector& a_ms_ks_strides_) - { - assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && - a_ms_ks_strides_.size() == NumDimM + NumDimK); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); - const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for K0, K1, ... - constexpr auto kDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); - - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] - const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); - - // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] - const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( - a_grid_desc_ms_ks, - make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), - make_tuple(mDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - } - - __host__ __device__ static auto - MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, - const std::array, NumATensor>& as_ms_ks_strides) - { - return generate_tuple( - [&](auto i) { - return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); - }, - Number{}); - } - - // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] - static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, - const std::vector& b_ns_ks_strides_) - { - assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && - b_ns_ks_strides_.size() == NumDimN + NumDimK); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); - const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; - - // dimension Ids for K0, K1, ... - constexpr auto kDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for K0, K1, ... - const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); - - // lengths for N0, N1, ... - const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); - - // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] - const auto b_grid_desc_ns_ks = - make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); - - // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] - const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( - b_grid_desc_ns_ks, - make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), - make_tuple(nDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - } - - __host__ __device__ static auto - MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, - const std::array, NumBTensor>& bs_ns_ks_strides) - { - return generate_tuple( - [&](auto i) { - return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); - }, - Number{}); - } - - // assume E[M0, M1, M2, ..., N0, N1, N2...] - static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, - const std::vector& e_ms_ns_strides_) - { - assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && - e_ms_ns_strides_.size() == NumDimM + NumDimN); - - const auto to_tuple = [&](auto& vec, auto num) { - return generate_tuple([&](auto i) { return vec[i]; }, num); - }; - - const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); - const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); - - // dimension Ids for M0, M1, ... - constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; - - // dimension Ids for N0, N1, ... - constexpr auto nDimIds = - typename arithmetic_sequence_gen::type{}; - - // lengths for M0, M1, ... - const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); - - // lengths for K0, K1, ... - const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); - - // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] - const auto e_grid_desc_ms_ns = - make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); - - // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] - const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( - e_grid_desc_ms_ns, - make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), - make_tuple(mDimIds, nDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - - static auto - MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, - const std::array, NumDTensor>& ds_ms_ns_strides) - { - return generate_tuple( - [&](auto i) { - return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); - }, - Number{}); - } - - // desc for problem definition - using AsGridDesc_M_K = remove_cvref_t; - using BsGridDesc_N_K = remove_cvref_t; - using DsGridDesc_M_N = remove_cvref_t; - using EGridDesc_M_N = remove_cvref_t; - - // desc for blockwise copy - using AsGridDesc_AK0_M_AK1 = - remove_cvref_t; - using BsGridDesc_BK0_N_BK1 = - remove_cvref_t; - using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< - decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - DsGridDesc_M_N{}))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; - - // block-to-e-tile map - using Block2ETileMap = - remove_cvref_t; - - // Argument - struct Argument : public BaseArgument - { - Argument(std::array p_as_grid, - std::array p_bs_grid, - std::array p_ds_grid, - void* p_e_grid, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - : p_as_grid_{}, - p_bs_grid_{}, - p_ds_grid_{}, - p_e_grid_{static_cast(p_e_grid)}, - as_grid_desc_m_k_{}, - bs_grid_desc_n_k_{}, - ds_grid_desc_m_n_{}, - e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, - as_grid_desc_ak0_m_ak1_{}, - bs_grid_desc_bk0_n_bk1_{}, - ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, - block_2_etile_map_{GridwiseGemm::MakeBlock2ETileMap(e_grid_desc_m_n_)}, - a_element_op_{a_element_op}, - b_element_op_{b_element_op}, - cde_element_op_{cde_element_op} - { - // populate pointer, desc for As - static_for<0, NumATensor, 1>{}([&](auto i) { - // using ALayout = remove_cvref_t>; - using ADataType = remove_cvref_t>; - - // A pointer - p_as_grid_(i) = static_cast(p_as_grid[i]); - - // A desc - as_grid_desc_m_k_(i) = - MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); - }); - - // populate pointer, desc for Bs - static_for<0, NumBTensor, 1>{}([&](auto i) { - // using BLayout = remove_cvref_t>; - using BDataType = remove_cvref_t>; - - // B pointer - p_bs_grid_(i) = static_cast(p_bs_grid[i]); - - // B desc - bs_grid_desc_n_k_(i) = - MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); - }); - - // populate pointer, desc for Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - // using DLayout = remove_cvref_t>; - using DDataType = remove_cvref_t>; - - // D pointer - p_ds_grid_(i) = static_cast(p_ds_grid[i]); - - // D desc - ds_grid_desc_m_n_(i) = - MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); - }); - - // populate desc for Ds/E - if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, - bs_grid_desc_n_k_, - ds_grid_desc_m_n_, - e_grid_desc_m_n_, - block_2_etile_map_)) - { - as_grid_desc_ak0_m_ak1_ = - GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); - - bs_grid_desc_bk0_n_bk1_ = - GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); - - ds_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n_); - - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n_); - } - - // for sanity check of vector memory access - for(index_t i = 0; i < NumATensor; ++i) - { - a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; - a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; - } - - for(index_t i = 0; i < NumBTensor; ++i) - { - b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1]; - b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; - } - - for(index_t i = 0; i < NumDTensor; ++i) - { - ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; - } - - e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; - } - - // pointers - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - EDataType* p_e_grid_; - - // tensor descriptors for problem definiton - AsGridDesc_M_K as_grid_desc_m_k_; - BsGridDesc_N_K bs_grid_desc_n_k_; - DsGridDesc_M_N ds_grid_desc_m_n_; - EGridDesc_M_N e_grid_desc_m_n_; - - // tensor descriptors for block/thread-wise copy - AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; - BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; - DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - - // block-to-e-tile map - Block2ETileMap block_2_etile_map_; - - // element-wise op - AElementwiseOperation a_element_op_; - BElementwiseOperation b_element_op_; - CDEElementwiseOperation cde_element_op_; - - // Strides for the last M/N/K dimensions of A/B/Ds/E - // for sanity check of vector load/store - std::array a_mz_stride_; - std::array a_kz_stride_; - - std::array b_nz_stride_; - std::array b_kz_stride_; - - std::array ds_nz_stride_; - index_t e_nz_stride_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - const index_t grid_size = - arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); - - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - - const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< - GridwiseGemm, - typename GridwiseGemm::AsGridPointer, - typename GridwiseGemm::BsGridPointer, - typename GridwiseGemm::DsGridPointer, - EDataType, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - DeviceOp::AsGridDesc_AK0_M_AK1, - DeviceOp::BsGridDesc_BK0_N_BK1, - DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - DeviceOp::Block2ETileMap, - has_main_loop>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_as_grid_, - arg.p_bs_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.as_grid_desc_ak0_m_ak1_, - arg.bs_grid_desc_bk0_n_bk1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_etile_map_); - }; - - const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); - - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static bool IsSupportedArgument(const Argument& arg) - { - if(!ck::is_xdl_supported()) - { - return false; - } - - // check vector load/store - { - bool all_valid = true; - - static_for<0, NumATensor, 1>{}([&](auto i) { - // vector memory access of A: could be on M or AK1 dimension - if constexpr(ABlockTransferSrcVectorDim == 1) - { - if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % - ABlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - }); - - // vector memory access of B: could be on N or BK1 dimension - static_for<0, NumBTensor, 1>{}([&](auto i) { - if constexpr(BBlockTransferSrcVectorDim == 1) - { - if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - else - { - if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) % - BBlockTransferSrcScalarPerVector == - 0)) - { - all_valid = false; - } - } - }); - - // check vector load of Ds - static_for<0, NumDTensor, 1>{}([&](auto i) { - if(!(arg.ds_nz_stride_[i] == 1 && - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - all_valid = false; - } - }); - - // vector memory access of E: always on NPerBlock dimension - if(!(arg.e_nz_stride_ == 1 && - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % - CDEBlockTransferScalarPerVector_NPerBlock == - 0)) - { - all_valid = false; - } - - if(!all_valid) - { - return false; - } - } - - return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, - arg.bs_grid_desc_n_k_, - arg.ds_grid_desc_m_n_, - arg.e_grid_desc_m_n_, - arg.block_2_etile_map_); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& a_ms_ks_lengths, - const std::array, NumATensor>& a_ms_ks_strides, - const std::array, NumBTensor>& b_ns_ks_lengths, - const std::array, NumBTensor>& b_ns_ks_strides, - const std::array, NumDTensor>& d_ms_ns_lengths, - const std::array, NumDTensor>& d_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) - { - return Argument{p_as, - p_bs, - p_ds, - p_e, - a_ms_ks_lengths, - a_ms_ks_strides, - b_ns_ks_lengths, - b_ns_ks_strides, - d_ms_ns_lengths, - d_ms_ns_strides, - e_ms_ns_length, - e_ms_ns_stride, - a_element_op, - b_element_op, - cde_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr - MakeArgumentPointer(std::array p_as, - std::array p_bs, - std::array p_ds, - void* p_e, - const std::array, NumATensor>& as_ms_ks_lengths, - const std::array, NumATensor>& as_ms_ks_strides, - const std::array, NumBTensor>& bs_ns_ks_lengths, - const std::array, NumBTensor>& bs_ns_ks_strides, - const std::array, NumDTensor>& ds_ms_ns_lengths, - const std::array, NumDTensor>& ds_ms_ns_strides, - const std::vector& e_ms_ns_length, - const std::vector& e_ms_ns_stride, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op) override - { - return std::make_unique(p_as, - p_bs, - p_ds, - p_e, - as_ms_ks_lengths, - as_ms_ks_strides, - bs_ns_ks_lengths, - bs_ns_ks_strides, - ds_ms_ns_lengths, - ds_ms_ns_strides, - e_ms_ns_length, - e_ms_ns_stride, - a_element_op, - b_element_op, - cde_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - std::map LoopSchedToString{ - {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; - - std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, - {PipelineVersion::v2, "v2"}}; - - // clang-format off - str << "DeviceContractionMultipleABD_Xdl_CShuffle" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << KPerBlock << ", " - << AK1 << ", " - << BK1 << ", " - << MPerXDL << ", " - << NPerXDL << ", " - << MXdlPerWave << ", " - << NXdlPerWave << ", " - << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector << ", " - << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << getGemmSpecializationString(GemmSpec) - << ">" - << " LoopScheduler: " - << LoopSchedToString[LoopSched] << ", " - << "PipelineVersion: " - << PipelineVersionToString[PipelineVer]; - // clang-format on - - return str.str(); - } -}; - -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index db119178b7..5ae06836fa 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -112,7 +112,6 @@ template + CDEElementwiseOperation> { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; @@ -312,6 +310,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + using ComputeDataType = ADataType; + // GridwiseGemm using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 1b9e37b74a..9fe0931cba 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -186,25 +186,6 @@ struct Bilinear y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); }; - template <> - __host__ __device__ constexpr void - operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const - { - const float x0_tmp = type_convert(x0); - const float x1_tmp = type_convert(x1); - const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp; - y = type_convert(y_tmp); - }; - - template <> - __host__ __device__ constexpr void - operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const - { - const float x1_tmp = ck::type_convert(x1); - const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; - y = y_tmp; - }; - template <> __host__ __device__ constexpr void operator()( std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 337117551f..0b0e731c60 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -33,12 +33,6 @@ struct PassThrough y = type_convert(x); } - template <> - __host__ __device__ void operator()(double& y, const float& x) const - { - y = type_convert(x); - } - template <> __host__ __device__ void operator()(float& y, const float& x) const { @@ -75,12 +69,6 @@ struct PassThrough y = type_convert(x); } - template <> - __host__ __device__ void operator()(float& y, const bhalf_t& x) const - { - y = type_convert(x); - } - template <> __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const { @@ -243,20 +231,6 @@ struct Scale template __host__ __device__ void operator()(Y& y, const X& x) const; - template <> - __host__ __device__ void operator()(half_t& y, const half_t& x) const - { - y = ck::type_convert(scale_) * x; - }; - - template <> - __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const - { - const float x_tmp = ck::type_convert(x); - const float y_tmp = scale_ * x_tmp; - y = ck::type_convert(y_tmp); - }; - template <> __host__ __device__ void operator()(float& y, const float& x) const { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 6321de73a9..10a0842dac 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -428,7 +428,7 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle [&](auto i) { using ALayout = remove_cvref_t>; - return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + return MakeAGridDescriptor_M_N(MRaws[i], KRaws[i], AsStride[i]); }, Number{}); } @@ -657,7 +657,6 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, ComputeDataType, - ComputeDataType, AccDataType, decltype(a_block_desc_ak0_m_ak1), decltype(b_block_desc_bk0_n_bk1), diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp index 527dac6d3b..92a8c82a6e 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_contraction.hpp @@ -23,7 +23,6 @@ template = false> @@ -70,24 +69,19 @@ struct ReferenceContraction_M2_N2_K2 : public ck::tensor_operation::device::Base { for(ck::index_t k1 = 0; k1 < K1; ++k1) { - // Simulate the possible casting when ComputeDataType is different than the - // A/B data types - ComputeDataType v_a_compute_input = - ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1)); - ComputeDataType v_b_compute_input = - ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1)); - AccDataType v_a; AccDataType v_b; - arg.a_element_op_(v_a, ck::type_convert(v_a_compute_input)); - arg.b_element_op_(v_b, ck::type_convert(v_b_compute_input)); + arg.a_element_op_( + v_a, ck::type_convert(arg.a_ms_ks_(m0, m1, k0, k1))); + arg.b_element_op_( + v_b, ck::type_convert(arg.b_ns_ks_(n0, n1, k0, k1))); v_acc += v_a * v_b; } } - arg.c_ms_ns_(m0, m1, n0, n1) = ck::type_convert(v_acc); + arg.c_ms_ns_(m0, m1, n0, n1) = v_acc; }; make_ParallelTensorFunctor(f_ms_ns, diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index f6fb3f9eee..ea11fd2e1a 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -29,8 +29,6 @@ using BF8 = ck::bf8_t; using Empty_Tuple = ck::Tuple<>; -using BF16_Tuple = ck::Tuple; - using F16_Tuple = ck::Tuple; using F16_F16_Tuple = ck::Tuple; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp deleted file mode 100644 index 96f4d4e30e..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp +++ /dev/null @@ -1,292 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; - -using F16_Tuple = ck::Tuple; -using BF16_Tuple = ck::Tuple; -using F32_Tuple = ck::Tuple; -using F64_Tuple = ck::Tuple; -using Empty_Tuple = ck::Tuple<>; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using Bilinear = ck::tensor_operation::element_wise::Bilinear; -using Scale = ck::tensor_operation::element_wise::Scale; - -template -using S = ck::Sequence; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -template -using device_contraction_kk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - -template -using device_contraction_kn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_mk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_mn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> - // clang-format on - >; - -template -using device_contraction_f64_kk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> - // clang-format on - >; - -template -using device_contraction_f64_kn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -template -using device_contraction_f64_mk_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -template -using device_contraction_f64_mn_instance = std::tuple< - // clang-format off - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, - DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, ComputeDataType, AElementwiseOp, BElementwiseOp, CDEElementwiseOp, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> - // clang-format on - >; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp index d06cab1196..bd3af891ec 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_bilinear.hpp @@ -17,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #ifdef CK_ENABLE_FP32 +// float void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP32 - + Bilinear>>>& instances); +#endif #ifdef CK_ENABLE_FP64 +// double void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( std::vector>>& instances); + Bilinear>>>& instances); void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP64 - -#ifdef CK_ENABLE_FP16 -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - -#ifdef CK_ENABLE_BF16 -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - std::vector>>& instances); - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - + Bilinear>>>& instances); +#endif // Contraction + Bilinear template + typename EDataType> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::Bilinear>> { using DeviceOp = DeviceContractionMultipleD; + ck::tensor_operation::element_wise::Bilinear>; static auto GetInstances() { std::vector> op_ptrs; #ifdef CK_ENABLE_FP32 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - op_ptrs); - } + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP32 +#endif #ifdef CK_ENABLE_FP64 if constexpr(is_same_v && is_same_v && - is_same_v) + is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - op_ptrs); - } + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( + op_ptrs); + add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP64 -#ifdef CK_ENABLE_FP16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_FP16 -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - op_ptrs); - add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_BF16 +#endif return op_ptrs; } }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp index 8e994d61ae..4dd735d062 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/contraction_scale.hpp @@ -17,6 +17,7 @@ namespace tensor_operation { namespace device { namespace instance { #ifdef CK_ENABLE_FP32 +// float void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP32 - + Scale>>>& instances); +#endif #ifdef CK_ENABLE_FP64 +// double void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( std::vector>>& instances); + Scale>>>& instances); void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP64 - -#ifdef CK_ENABLE_FP16 -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - -#ifdef CK_ENABLE_BF16 -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - std::vector>>& instances); - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - std::vector>>& instances); -#endif // CK_ENABLE_FP16 - + Scale>>>& instances); +#endif // Contraction + Scale template + typename EDataType> struct DeviceOperationInstanceFactory> + ck::tensor_operation::element_wise::Scale>> { using DeviceOp = DeviceContractionMultipleD; + ck::tensor_operation::element_wise::Scale>; static auto GetInstances() { @@ -430,113 +155,34 @@ struct DeviceOperationInstanceFactory) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - op_ptrs); - } + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP32 +#endif #ifdef CK_ENABLE_FP64 if constexpr(is_same_v && is_same_v && is_same_v) { if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( - op_ptrs); - } - else if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - op_ptrs); - } + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( + op_ptrs); + add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( + op_ptrs); } } -#endif // CK_ENABLE_FP64 -#ifdef CK_ENABLE_FP16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_FP16 -#ifdef CK_ENABLE_BF16 - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2) - { - if constexpr(is_same_v) - { - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - op_ptrs); - add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - op_ptrs); - } - } - } -#endif // CK_ENABLE_BF16 +#endif return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt index 87a6bbba4e..ee9c419e1a 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/CMakeLists.txt @@ -1,42 +1,14 @@ set(DEVICE_CONTRACTION_BILINEAR_INSTANCES) - -# FP32 +#float list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp) - -# FP64 +#double list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp) + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp + device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp) add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp deleted file mode 100644 index 56dc1d2c71..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp deleted file mode 100644 index 7b8163814e..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp deleted file mode 100644 index ba202202f9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp deleted file mode 100644 index b6acab1c18..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp deleted file mode 100644 index 1bd3495b36..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp deleted file mode 100644 index 6817546d3d..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp deleted file mode 100644 index cedace2cb2..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp deleted file mode 100644 index b8383e5268..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp deleted file mode 100644 index 2e73bdb804..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp deleted file mode 100644 index ae69db03ac..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp deleted file mode 100644 index fc6b6003f7..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp deleted file mode 100644 index b81984acc4..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_bf16_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp deleted file mode 100644 index 8da6a79f0a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance = - device_contraction_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp deleted file mode 100644 index b05c443e85..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance = - device_contraction_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp deleted file mode 100644 index c2cbb9f25f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance = - device_contraction_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp deleted file mode 100644 index 887eaad7eb..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_compute_f16_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp index 9a3c71acb5..5587db77e0 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,40 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = - device_contraction_kk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp index 5659d477d4..26262855ea 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = - device_contraction_kn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp index 7cbe1e34be..befc0dcd10 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = - device_contraction_mk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp index e26dcf394b..e45b47cf94 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { +using F32 = float; +using F32_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = - device_contraction_mn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, F32_Tuple, F32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp deleted file mode 100644 index 74c5c9c653..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance = - device_contraction_f64_kk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_kknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp deleted file mode 100644 index 8a2705f722..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance = - device_contraction_f64_kn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_knnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp deleted file mode 100644 index 47924a843d..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance = - device_contraction_f64_mk_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mknn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp deleted file mode 100644 index d6c684e2b0..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance = - device_contraction_f64_mn_instance; - -void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_compute_f32_mnnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp index 75bd8e3f4a..f437a227d5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = - device_contraction_f64_kk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp index c515f37916..13fdbeb35c 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = - device_contraction_f64_kn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp index ff9fa61e0b..95ef8c4929 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = - device_contraction_f64_mk_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp index f274c45b58..290f81d7c9 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_bilinear/device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using F64_Tuple = ck::Tuple; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = - device_contraction_f64_mn_instance; +using device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, F64_Tuple, F64, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance( std::vector>>& instances) + Bilinear>>>& instances) { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt index a0918d9d3f..673fae3373 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/CMakeLists.txt @@ -1,43 +1,15 @@ set(DEVICE_CONTRACTION_SCALE_INSTANCES) - -# FP32 +#float list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp) - -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp) - -# FP64 +#double list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp) -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp) - -# FP16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp) - -# BF16 -list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp) - add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp deleted file mode 100644 index ace4d4a330..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp deleted file mode 100644 index d2a1d45f7a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp deleted file mode 100644 index aa4e4df9d9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp deleted file mode 100644 index a59bce0a1f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp deleted file mode 100644 index 876acd3f69..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp deleted file mode 100644 index 96e961d3bb..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp deleted file mode 100644 index 69cae90889..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp deleted file mode 100644 index e6ac98273a..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp deleted file mode 100644 index da1df4c43e..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp deleted file mode 100644 index 21c4e87fb1..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp deleted file mode 100644 index 9d22f8feff..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp deleted file mode 100644 index 8e861f74c9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_bf16_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp deleted file mode 100644 index 390b5526e3..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance = - device_contraction_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp deleted file mode 100644 index 385a13cc25..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance = - device_contraction_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp deleted file mode 100644 index 1afd432bac..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance = - device_contraction_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp deleted file mode 100644 index bfb4d4630f..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance = - device_contraction_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_compute_f16_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp index 51bc64eaad..16fd1cb407 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,40 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = - device_contraction_kk_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp index 4cd28d52dd..ff37bf7cce 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = - device_contraction_kn_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// k/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 1, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 1, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 1, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 1, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp index 82bba62715..8a1f6f9334 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = - device_contraction_mk_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/k/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 4, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 4, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp index 06bf4f34b5..d333f59726 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,43 @@ namespace tensor_operation { namespace device { namespace instance { -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = - device_contraction_mn_instance; +using F32 = float; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] = E[m0, m1, n0, n1] +// m/n/n are the fast changing dimension for A/B/E +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 1, 1, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp deleted file mode 100644 index c047153bd9..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance = - device_contraction_f64_kk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_kkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp deleted file mode 100644 index ee33c56223..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance = - device_contraction_f64_kn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_knn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp deleted file mode 100644 index b62d4bb2aa..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance = - device_contraction_f64_mk_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mkn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp deleted file mode 100644 index 346f91beb0..0000000000 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -// This (ifndef) is a hack to use customized behavior for buffer load rather than using default -// setting Don't use this hack unless absolutely necessary! -// FIXME: make the behavior of buffer load a configurable (template) parameter of each device op -#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -// A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] -// m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance = - device_contraction_f64_mn_instance; - -void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32_mnn_instance{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp index 30cc3b9652..4c87b51a9e 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = - device_contraction_f64_kk_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 32, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 64, 32, 16, 2, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 64, 32, 64, 16, 2, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp index 0a36ef7ffd..fd3f57c6b6 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // k/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = - device_contraction_f64_kn_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 1, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 1, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp index 617332c2c7..1e53f0b2f5 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/k/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = - device_contraction_f64_mk_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 2, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 2, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 2, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance{}); diff --git a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp index 1df7f31201..d02d146a9b 100644 --- a/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/contraction_scale/device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp @@ -9,9 +9,11 @@ #include #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/tensor_operation_instance/gpu/contraction/device_contraction_instance.hpp" + #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" namespace ck { @@ -19,19 +21,37 @@ namespace tensor_operation { namespace device { namespace instance { +using F64 = double; +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + // A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1] // m/n/n/n are the fast changing dimension for A/B/D/E -using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = - device_contraction_f64_mn_instance; +using device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance = std::tuple< + // clang-format off + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 1, 1, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 128, 64, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 1, 1, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 128, 64, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 8, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 1, 1, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 128, 64, 16, 2, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 0, 1, 1, S<1, 16, 1, 16>, 1>, + DeviceContractionMultipleD_Xdl_CShuffle< 2, 2, 2, F64, F64, F64, F64, Empty_Tuple, F64, PassThrough, PassThrough, Scale, GemmMNKPadding, 1, 256, 64, 128, 16, 2, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, 1, 1, 1, S<1, 16, 1, 16>, 1> + // clang-format on + >; void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance( std::vector>>& instances) + Scale>>>& instances) { add_device_operation_instances( instances, device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance{}); diff --git a/profiler/README.md b/profiler/README.md index b7b5be420b..d081fc33e9 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -50,23 +50,21 @@ Best Perf: 1.42509 ms, 102.988 TFlops, 234.086 GB/s ## Profile contraction kernels ```bash #arg1: tensor operation (contraction_bilinear=CONTRACTION+Bilinear) -#arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16) -#arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16) -#arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; +#arg2: data type (0: fp32; 1: f64)\n" +#arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]; # 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + D[m0, m1, n0, n1] = E[m0, m1, n0, n1]) -#arg5: verification (0: no; 1: yes) -#arg6: initialization (0: no init; 1: integer value; 2: decimal value) -#arg7: print tensor value (0: no; 1: yes) -#arg8: time kernel (0: no, 1: yes) -#arg9: alpha -#arg10: beta -#arg11 to 16: M0, M1, N0, N1, K0, K1 -#arg17 to 32: Strides for A, B, D and E (skip for default) +#arg4: verification (0: no; 1: yes) +#arg5: initialization (0: no init; 1: integer value; 2: decimal value) +#arg6: print tensor value (0: no; 1: yes) +#arg7: time kernel (0: no, 1: yes) +#arg8 and arg9: alpha and beta +#arg10 to 15: M0, M1, N0, N1, K0, K1 +#arg16 to 31: Strides for A, B, D and E (skip for default) -################ op datatype compute_datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 -./bin/ckProfiler contraction_bilinear 0 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 +################ op datatype layout verify init log time alpha beta M0 M1 N0 N1 K0 K1 +./bin/ckProfiler contraction_bilinear 0 1 0 0 0 1 1.0 1.0 128 128 128 128 128 128 ``` Result (MI100) diff --git a/profiler/include/profiler/profile_contraction_impl.hpp b/profiler/include/profiler/profile_contraction_impl.hpp index f6e4b3f395..660cc3f9e5 100644 --- a/profiler/include/profiler/profile_contraction_impl.hpp +++ b/profiler/include/profiler/profile_contraction_impl.hpp @@ -31,14 +31,10 @@ namespace profiler { using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; -using F32 = float; -using F64 = double; - template int profile_contraction_impl(ck::index_t do_verification, @@ -49,10 +45,10 @@ int profile_contraction_impl(ck::index_t do_verification, const std::vector& M, const std::vector& N, const std::vector& K, - const std::vector& StridesA, // [M0, M1, K0, K1] - const std::vector& StridesB, // [N0, N1, K0, K1] - const std::vector& StridesE, // [M0, M1, N0, N1] - const std::vector& StridesD) // [M0, M1, N0, N1] + const std::vector& StridesA, + const std::vector& StridesB, + const std::vector& StridesE, + const std::vector& StridesD) { bool pass = true; @@ -67,13 +63,13 @@ int profile_contraction_impl(ck::index_t do_verification, }; Tensor a_m_k(f_host_tensor_descriptor(M, K, StridesA)); - Tensor b_n_k(f_host_tensor_descriptor(N, K, StridesB)); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StridesB)); Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StridesE)); Tensor d_m_n(f_host_tensor_descriptor(M, N, StridesD)); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "b_n_k: " << b_n_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "d_m_n: " << d_m_n.mDesc << std::endl; std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; @@ -82,12 +78,12 @@ int profile_contraction_impl(ck::index_t do_verification, case 0: break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); d_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d_m_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } @@ -95,12 +91,12 @@ int profile_contraction_impl(ck::index_t do_verification, using BElementOp = ck::tensor_operation::element_wise::PassThrough; DeviceMem a_device_buf(sizeof(DataType) * a_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(DataType) * b_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(DataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(DataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem d_device_buf(sizeof(DataType) * d_m_n.mDesc.GetElementSpaceSize()); a_device_buf.ToDevice(a_m_k.mData.data()); - b_device_buf.ToDevice(b_n_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); e_device_buf.SetZero(); d_device_buf.ToDevice(d_m_n.mData.data()); @@ -122,8 +118,7 @@ int profile_contraction_impl(ck::index_t do_verification, DataType, AElementOp, BElementOp, - CDElementOp, - ComputeDataType>; + CDElementOp>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -131,9 +126,6 @@ int profile_contraction_impl(ck::index_t do_verification, std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - using AccDataType = - typename std::conditional::value, F64, F32>::type; - // Run reference op if(do_verification) { @@ -144,8 +136,7 @@ int profile_contraction_impl(ck::index_t do_verification, DataType, DataType, DataType, - AccDataType, - ComputeDataType, + DataType, AElementOp, BElementOp>; @@ -155,7 +146,7 @@ int profile_contraction_impl(ck::index_t do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StridesE)); auto ref_argument = - ref_op.MakeArgument(a_m_k, b_n_k, c_m_n_host_result, a_element_op, b_element_op); + ref_op.MakeArgument(a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op); ref_invoker.Run(ref_argument); @@ -281,29 +272,8 @@ int profile_contraction_impl(ck::index_t do_verification, { e_device_buf.FromDevice(e_m_n_device_result.mData.data()); - // Both the kernel and the reference use `AccDataType`, so an absolute error of both - // of them is bounded by `nelems_k * std::numeric_limits::epsilon()`. - // Comparing one to another can result in an absolute error as high as twice that - // value. - double threshold = 2 * nelems_k * std::numeric_limits::epsilon(); - // Handle the possible casting error of either AccDataType -> DataType or - // DataType -> ComputeDataType. - // TODO: Add a generic solution for calculating thresholds in CK. - if constexpr(ck::is_same_v || - ck::is_same_v) - { - const double epsilon = std::pow(2, -7); - // Maximum relative casting error when rounding to zero. - threshold += epsilon * 2; - } - else if constexpr(ck::is_same_v || - ck::is_same_v) - { - const double epsilon = std::pow(2, -10); - // Maximum relative casting error when rounding to zero. - threshold += epsilon * 2; - } - + float threshold = + static_cast(nelems_k) * std::numeric_limits::epsilon(); pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result, "Error: incorrect results!", @@ -313,7 +283,7 @@ int profile_contraction_impl(ck::index_t do_verification, if(do_log) { LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b_n_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_host : ", e_m_n_host_result.mData, ",") << std::endl; LogRangeAsType(std::cout << "c_device: ", e_m_n_device_result.mData, ",") diff --git a/profiler/include/profiler/profile_contraction_utils.hpp b/profiler/include/profiler/profile_contraction_utils.hpp index 05ec7daf9d..076bbd4559 100644 --- a/profiler/include/profiler/profile_contraction_utils.hpp +++ b/profiler/include/profiler/profile_contraction_utils.hpp @@ -23,18 +23,8 @@ enum struct ContractionMatrixLayout enum struct ContractionDataType { - F32_F32_F32_F32, // 0 - F64_F64_F64_F64, // 1 - F16_F16_F16_F16, // 2 - BF16_BF16_BF16_BF16, // 3 -}; - -enum struct ContractionComputeDataType -{ - F32 = 0, - F64, - F16, - BF16, + F32_F32_F32_F32, // 0 + F64_F64_F64_F64, // 1 }; inline void collect_index_params(char* argv[], diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp index 8cb34ccc6c..6ed1841204 100644 --- a/profiler/src/profile_contraction_bilinear.cpp +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -17,9 +17,8 @@ static void print_helper_msg() { std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,42 +26,40 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10: beta\n" - << "arg11 to 16: M0, M1, N0, N1, K0, K1\n" - << "arg17 to 32: Strides for A, B, D and E (skip for default)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8 and arg9: alpha and beta\n" + << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" + << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" << std::endl; } int profile_contraction_bilinear(int argc, char* argv[]) { - const bool default_strides = argc == 17; + const bool default_strides = argc == 16; - if(argc != 33 && argc != 17) + if(argc != 32 && argc != 16) { print_helper_msg(); exit(1); } const auto data_type = static_cast(std::stoi(argv[2])); - const auto compute_data_type = static_cast(std::stoi(argv[3])); - const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); - const float beta = std::stof(argv[10]); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); + const float beta = std::stof(argv[9]); std::vector M; std::vector N; std::vector K; - const ck::index_t dims_arg_num = 11; + const ck::index_t dims_arg_num = 10; collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, K, dims_arg_num + 4, 2); @@ -79,130 +76,90 @@ int profile_contraction_bilinear(int argc, char* argv[]) collect_index_params(argv, StridesD, dims_arg_num + 18, 4); } - using F16 = ck::half_t; - using BF16 = ck::bhalf_t; - using F32 = float; - using F64 = double; + using F32 = float; + using F64 = double; - auto profile = - [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CDELayout = decltype(cde_layout); + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); - using DataType = decltype(type); - using ComputeDataType = decltype(compute_type); + using DataType = decltype(type); - if(default_strides) - { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); - } - bool pass = ck::profiler::profile_contraction_impl, - Bilinear>(do_verification, - init_method, - do_log, - time_kernel, - Bilinear{alpha, beta}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; - }; - - auto run_profile_for_datatype = [&](auto type, auto compute_type) { - if(layout == ContractionMatrixLayout::MK_KN_MN_MN) + if(default_strides) { - return profile(Row{}, Row{}, Row{}, type, compute_type); + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); } - else if(layout == ContractionMatrixLayout::MK_NK_MN_MN) - { - return profile(Row{}, Col{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_KN_MN_MN) - { - return profile(Col{}, Row{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_NK_MN_MN) - { - return profile(Col{}, Col{}, Row{}, type, compute_type); - } - return false; + bool pass = ck::profiler::profile_contraction_impl, + Bilinear>(do_verification, + init_method, + do_log, + time_kernel, + Bilinear{alpha, beta}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; }; - if(data_type == ContractionDataType::F32_F32_F32_F32) + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F32{}, F32{}); - } - else if(compute_data_type == ContractionComputeDataType::F16) - { - return run_profile_for_datatype(F32{}, F16{}); - } - else if(compute_data_type == ContractionComputeDataType::BF16) - { - return run_profile_for_datatype(F32{}, BF16{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F64_F64_F64_F64) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F64) - { - return run_profile_for_datatype(F64{}, F64{}); - } - else if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F64{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Col{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F16_F16_F16_F16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(BF16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; } - return 1; } REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear); diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp index ca9c199988..6784b916f6 100644 --- a/profiler/src/profile_contraction_scale.cpp +++ b/profiler/src/profile_contraction_scale.cpp @@ -17,9 +17,8 @@ static void print_helper_msg() { std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" - << "arg2: data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg3: compute data type (0: fp32; 1: f64; 2: f16; 3: bf16)\n" - << "arg4: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " + << "arg2: data type (0: fp32; 1: f64)\n" + << "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" @@ -27,40 +26,39 @@ static void print_helper_msg() "D[m0, m1, n0, n1] = E[m0, m1, n0, n1];\n" << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" - << "arg5: verification (0: no; 1: yes)\n" - << "arg6: initialization (0: no init; 1: integer value; 2: decimal " + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" - << "arg7: print tensor value (0: no; 1: yes)\n" - << "arg8: time kernel (0: no, 1: yes)\n" - << "arg9: alpha\n" - << "arg10 to 15: M0, M1, N0, N1, K0, K1\n" - << "arg16 to 31: Strides for A, B, D and E (skip for default)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "arg8: alpha\n" + << "arg9 to 14: M0, M1, N0, N1, K0, K1\n" + << "arg15 to 30: Strides for A, B, D and E (skip for default)\n" << std::endl; } int profile_contraction_scale(int argc, char* argv[]) { - const bool default_strides = argc == 16; + const bool default_strides = argc == 15; - if(argc != 32 && argc != 16) + if(argc != 31 && argc != 15) { print_helper_msg(); exit(1); } const auto data_type = static_cast(std::stoi(argv[2])); - const auto compute_data_type = static_cast(std::stoi(argv[3])); - const auto layout = static_cast(std::stoi(argv[4])); - const bool do_verification = std::stoi(argv[5]); - const ck::index_t init_method = std::stoi(argv[6]); - const bool do_log = std::stoi(argv[7]); - const bool time_kernel = std::stoi(argv[8]); - const float alpha = std::stof(argv[9]); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const ck::index_t init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + const float alpha = std::stof(argv[8]); std::vector M; std::vector N; std::vector K; - const ck::index_t dims_arg_num = 10; + const ck::index_t dims_arg_num = 9; collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, K, dims_arg_num + 4, 2); @@ -77,131 +75,88 @@ int profile_contraction_scale(int argc, char* argv[]) collect_index_params(argv, StridesD, dims_arg_num + 18, 4); } - using F16 = ck::half_t; - using BF16 = ck::bhalf_t; - using F32 = float; - using F64 = double; + using F32 = float; + using F64 = double; - auto profile = - [&](auto a_layout, auto b_layout, auto cde_layout, auto type, auto compute_type) { - using ALayout = decltype(a_layout); - using BLayout = decltype(b_layout); - using CDELayout = decltype(cde_layout); + auto profile = [&](auto a_layout, auto b_layout, auto cde_layout, auto type) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CDELayout = decltype(cde_layout); - using DataType = decltype(type); - using ComputeDataType = decltype(compute_type); + using DataType = decltype(type); - if(default_strides) - { - assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(b_layout, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); - assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); - } - - bool pass = ck::profiler::profile_contraction_impl, - Scale>(do_verification, - init_method, - do_log, - time_kernel, - Scale{alpha}, - M, - N, - K, - StridesA, - StridesB, - StridesE, - StridesD); - - return pass; - }; - - auto run_profile_for_datatype = [&](auto type, auto compute_type) { - if(layout == ContractionMatrixLayout::MK_KN_MN_MN) + if(default_strides) { - return profile(Row{}, Row{}, Row{}, type, compute_type); + assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]}); + assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesE, {M[0], M[1], N[0], N[1]}); + assign_default_strides(cde_layout, StridesD, {M[0], M[1], N[0], N[1]}); } - else if(layout == ContractionMatrixLayout::MK_NK_MN_MN) - { - return profile(Row{}, Col{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_KN_MN_MN) - { - return profile(Col{}, Row{}, Row{}, type, compute_type); - } - else if(layout == ContractionMatrixLayout::KM_NK_MN_MN) - { - return profile(Col{}, Col{}, Row{}, type, compute_type); - } - return false; + + bool pass = ck::profiler:: + profile_contraction_impl, Scale>( + do_verification, + init_method, + do_log, + time_kernel, + Scale{alpha}, + M, + N, + K, + StridesA, + StridesB, + StridesE, + StridesD); + + return pass; }; - if(data_type == ContractionDataType::F32_F32_F32_F32) + if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F32{}, F32{}); - } - else if(compute_data_type == ContractionComputeDataType::F16) - { - return run_profile_for_datatype(F32{}, F16{}); - } - else if(compute_data_type == ContractionComputeDataType::BF16) - { - return run_profile_for_datatype(F32{}, BF16{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F64_F64_F64_F64) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F64) - { - return run_profile_for_datatype(F64{}, F64{}); - } - else if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F64{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Row{}, Col{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::F16_F16_F16_F16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(F16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Row{}, Row{}, F32{}); } - else if(data_type == ContractionDataType::BF16_BF16_BF16_BF16) + else if(data_type == ContractionDataType::F32_F32_F32_F32 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) { - if(compute_data_type == ContractionComputeDataType::F32) - { - return run_profile_for_datatype(BF16{}, F32{}); - } - else - { - std::cout << "Incorrect combination of data type and compute data type." << std::endl; - return 1; - } + return profile(Col{}, Col{}, Row{}, F32{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_KN_MN_MN) + { + return profile(Row{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::MK_NK_MN_MN) + { + return profile(Row{}, Col{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_KN_MN_MN) + { + return profile(Col{}, Row{}, Row{}, F64{}); + } + else if(data_type == ContractionDataType::F64_F64_F64_F64 && + layout == ContractionMatrixLayout::KM_NK_MN_MN) + { + return profile(Col{}, Col{}, Row{}, F64{}); + } + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; } - return 1; } REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale); diff --git a/test/contraction/test_contraction.cpp b/test/contraction/test_contraction.cpp index 958d5be38c..c86b849235 100644 --- a/test/contraction/test_contraction.cpp +++ b/test/contraction/test_contraction.cpp @@ -10,12 +10,9 @@ #include #include "profiler/profile_contraction_impl.hpp" -#include "profiler/profile_contraction_utils.hpp" -using F16 = ck::half_t; -using BF16 = ck::bhalf_t; -using F32 = float; -using F64 = double; +using F32 = float; +using F64 = double; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -23,49 +20,49 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using Bilinear = ck::tensor_operation::element_wise::Bilinear; using Scale = ck::tensor_operation::element_wise::Scale; -struct Dimensions +struct MemoryParams { std::vector M; std::vector N; std::vector K; + std::vector StridesA; + std::vector StridesB; + std::vector StridesC; + std::vector StridesD; }; template class TestContraction : public ::testing::Test { protected: - using ALayout = std::tuple_element_t<0, Tuple>; - using BLayout = std::tuple_element_t<1, Tuple>; - using CDLayout = std::tuple_element_t<2, Tuple>; - using DataType = std::tuple_element_t<3, Tuple>; - using DTupleDataType = std::tuple_element_t<4, Tuple>; - using ComputeDataType = std::tuple_element_t<5, Tuple>; - using CDElementOp = std::tuple_element_t<6, Tuple>; + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CDLayout = std::tuple_element_t<2, Tuple>; + using DataType = std::tuple_element_t<3, Tuple>; + using DTupleDataType = std::tuple_element_t<4, Tuple>; + using CDElementOp = std::tuple_element_t<5, Tuple>; - std::vector dimension_list = {{{32, 32}, {32, 32}, {32, 32}}, - {{16, 16}, {32, 32}, {16, 16}}}; + std::vector list_of_memory_params = {{{32, 32}, + {32, 32}, + {32, 32}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}, + {32768, 1024, 32, 1}}, + {{16, 16}, + {32, 32}, + {16, 16}, + {4096, 256, 16, 1}, + {16, 1, 8192, 256}, + {16384, 1024, 32, 1}, + {16384, 1024, 32, 1}}}; - std::vector init_methods = {1, 2}; + std::vector init_methods = {0, 1, 2}; std::unique_ptr p_cd_element_op; - void Run() { - for(auto& dimension_params : dimension_list) + for(auto& memory_params : list_of_memory_params) { - std::vector StridesA; - std::vector StridesB; - std::vector StridesC; - std::vector StridesD; - - const auto& M = dimension_params.M; - const auto& N = dimension_params.N; - const auto& K = dimension_params.K; - - assign_default_strides(ALayout{}, StridesA, {M[0], M[1], K[0], K[1]}); - assign_default_strides(BLayout{}, StridesB, {N[0], N[1], K[0], K[1]}); - assign_default_strides(CDLayout{}, StridesC, {M[0], M[1], N[0], N[1]}); - assign_default_strides(CDLayout{}, StridesD, {M[0], M[1], N[0], N[1]}); - for(const ck::index_t init_method : init_methods) { bool pass = @@ -73,20 +70,19 @@ class TestContraction : public ::testing::Test BLayout, CDLayout, DataType, - ComputeDataType, DTupleDataType, CDElementOp>(true /*do_verification*/, init_method, false /*do_logs*/, false /*time_kernel*/, *p_cd_element_op, - dimension_params.M, - dimension_params.N, - dimension_params.K, - StridesA, - StridesB, - StridesC, - StridesD); + memory_params.M, + memory_params.N, + memory_params.K, + memory_params.StridesA, + memory_params.StridesB, + memory_params.StridesC, + memory_params.StridesD); EXPECT_TRUE(pass); } } @@ -103,18 +99,24 @@ class TestContractionBilinear : public TestContraction { }; -#define ALL_LAYOUT_COMBINATIONS(dt, tuple_dt, compute_dt, op) \ - std::tuple, \ - std::tuple, \ - std::tuple, \ - std::tuple - using BilinearKernelTypes = - ::testing::Types, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F64, Bilinear)>; + ::testing::Types, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>, + std::tuple, Bilinear>>; -using ScaleKernelTypes = ::testing::Types, F32, Scale), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F64, Scale)>; +using ScaleKernelTypes = ::testing::Types, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>, + std::tuple, Scale>>; TYPED_TEST_SUITE(TestContractionBilinear, BilinearKernelTypes); TYPED_TEST_SUITE(TestContractionScale, ScaleKernelTypes); @@ -134,46 +136,3 @@ TYPED_TEST(TestContractionScale, scale) this->p_cd_element_op = std::make_unique(0.5f); this->Run(); } - -template -class TestContractionScaleMixedPrecision : public TestContraction -{ -}; - -template -class TestContractionBilinearMixedPrecision : public TestContraction -{ -}; - -using BilinearKernelTypesMixedPrecision = - ::testing::Types, F16, Bilinear), - ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple, BF16, Bilinear), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple, F32, Bilinear), - ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple, F32, Bilinear)>; - -using ScaleKernelTypesMixedPrecision = - ::testing::Types, F16, Scale), - ALL_LAYOUT_COMBINATIONS(F32, ck::Tuple<>, BF16, Scale), - ALL_LAYOUT_COMBINATIONS(F64, ck::Tuple<>, F32, Scale), - ALL_LAYOUT_COMBINATIONS(F16, ck::Tuple<>, F32, Scale), - ALL_LAYOUT_COMBINATIONS(BF16, ck::Tuple<>, F32, Scale)>; - -TYPED_TEST_SUITE(TestContractionBilinearMixedPrecision, BilinearKernelTypesMixedPrecision); -TYPED_TEST_SUITE(TestContractionScaleMixedPrecision, ScaleKernelTypesMixedPrecision); - -TYPED_TEST(TestContractionBilinearMixedPrecision, bilinear) -{ - this->p_cd_element_op = std::make_unique(1.f, 1.f); - this->Run(); - this->p_cd_element_op = std::make_unique(-0.5f, 0.5f); - this->Run(); -} - -TYPED_TEST(TestContractionScaleMixedPrecision, scale) -{ - this->p_cd_element_op = std::make_unique(1.f); - this->Run(); - this->p_cd_element_op = std::make_unique(0.5f); - this->Run(); -} diff --git a/test/contraction/test_contraction_interface.cpp b/test/contraction/test_contraction_interface.cpp index e0f547445e..12a307f5da 100644 --- a/test/contraction/test_contraction_interface.cpp +++ b/test/contraction/test_contraction_interface.cpp @@ -34,11 +34,11 @@ class ContractionInstanceWrapper static constexpr ck::index_t NumDim = 2; // clang-format off using ContractionDeviceInstance = ck::tensor_operation::device:: - //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| Compute| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //#####################################| | | | Type| Type| Type| DataType| Type| Type| Data| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //#####################################| | | | | | | | | | Type| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; + //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceContractionMultipleD_Xdl_CShuffle< NumDim, NumDim, NumDim, F32, F32, F32, F32, ck::Tuple, F32, Pass, Pass, Bilinear, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, ABlockTransferSrcVectorDim, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, BBlockTransferSrcVectorDim, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, CDEBlockTransferScalarPerVector>; // clang-format on bool isSupported(std::vector& ADims,