From 4581b5d504ad400dd36e641856e1a2ce7d91955d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 8 Feb 2024 21:35:39 -0800 Subject: [PATCH 01/36] Bump rocm-docs-core from 0.33.2 to 0.34.0 in /docs/sphinx (#1163) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.33.2 to 0.34.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.33.2...v0.34.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index a6b286b131..65341af8d6 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.33.2 +rocm-docs-core==0.34.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 4bbe95c934..74016ea8a2 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.33.2 +rocm-docs-core==0.34.0 # via -r requirements.in six==1.16.0 # via From 94fbaac0027245b41869e9ed6f2d46b10c432745 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Fri, 9 Feb 2024 10:20:53 -0600 Subject: [PATCH 02/36] add generic instances for DeviceGemm_Xdl_CShuffle (#1161) * add generic instances * clean code --- ...xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp | 12 ++++++++++++ ...xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp | 14 ++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp index 4a2526b3a4..8c9a96f6b7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -34,6 +34,15 @@ static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecializati static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + DeviceGemm_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + // Compilation parameters for a[m, k] * b[k, n] = c[m, n] template using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = std::tuple< @@ -108,6 +117,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( DeviceGemm>>& instances) { + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_generic_instances{}); + add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp index 01e0ebdb34..b591dacff5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -32,6 +32,17 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances = std::tuple< + // clang-format off + //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler| Pipeline| + //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | | + //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, MNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 8>, 1, LoopScheduler::Default, PipelineVersion::v1> + // clang-format on + >; + template // Compilation parameters for a[m, k] * b[n, k] = c[m, n] using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = std::tuple< @@ -97,6 +108,9 @@ void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( DeviceGemm>>& instances) { + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_generic_instances{}); + add_device_operation_instances( instances, device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); From 602c4cc0d934e78bf2f2c07d95916398014bd767 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 12 Feb 2024 11:45:42 -0600 Subject: [PATCH 03/36] Optimizing fp8_fp16 mixedprec gemm (#1150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add delayed cvt * extend fp16 gemm_splitk instances for fp8_fp16 gemm * add f8 example * add 128 kperblk instances for fp8 * add kpb128 instance * added more instances into kpb128 * clean code * clean code * fix * fix * fixed * Update example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp Co-authored-by: Bartłomiej Kocot * Update include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp Co-authored-by: Bartłomiej Kocot * Update library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp Co-authored-by: Bartłomiej Kocot --------- Co-authored-by: Jing Zhang Co-authored-by: Bartłomiej Kocot --- example/35_splitK_gemm/CMakeLists.txt | 3 + .../splitK_gemm_xdl_fp16_fp8.cpp | 60 +++++++ .../gpu/block/blockwise_gemm_xdlops.hpp | 69 ++++---- .../impl/device_gemm_xdl_splitk_c_shuffle.hpp | 12 +- .../element/unary_element_wise_operation.hpp | 39 ----- .../gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp | 25 +-- .../threadwise_tensor_slice_transfer.hpp | 61 ++++++-- .../gpu/gemm_splitk.hpp | 6 + .../gpu/gemm_splitk/CMakeLists.txt | 79 +++++----- ...k_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp | 147 ++++++++++++++++++ 10 files changed, 370 insertions(+), 131 deletions(-) create mode 100644 example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp diff --git a/example/35_splitK_gemm/CMakeLists.txt b/example/35_splitK_gemm/CMakeLists.txt index f98308d687..5277b32f63 100644 --- a/example/35_splitK_gemm/CMakeLists.txt +++ b/example/35_splitK_gemm/CMakeLists.txt @@ -10,6 +10,9 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16) + add_example_executable(example_splitK_gemm_xdl_fp16_fp8 splitK_gemm_xdl_fp16_fp8.cpp) + add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16_fp8) + add_example_executable(example_splitK_gemm_xdl_lds_direct_load_fp16 splitK_gemm_xdl_lds_direct_load_fp16.cpp) add_example_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_lds_direct_load_fp16) diff --git a/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp new file mode 100644 index 0000000000..b93639e6c1 --- /dev/null +++ b/example/35_splitK_gemm/splitK_gemm_xdl_fp16_fp8.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/literals.hpp" + +template +using S = ck::Sequence; + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F8; +using AccDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle + // clang-format off +//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Default, ADataType, BDataType>; + +// clang-format on + +#include "run_splitK_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_splitK_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 904a96cc9f..701dd04f6c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -37,7 +37,9 @@ template + index_t KPack, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { static constexpr auto I0 = Number<0>{}; @@ -59,7 +61,8 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -295,9 +298,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -319,20 +322,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -360,7 +363,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -370,7 +373,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -398,6 +401,8 @@ template struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + KPack, + ComputeTypeA, + ComputeTypeB> { using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + KPack, + ComputeTypeA, + ComputeTypeB>; #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING using Base::a_block_desc_m0_m1_m2_k; @@ -446,9 +455,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -485,22 +494,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -550,7 +559,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -560,7 +569,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -586,7 +595,9 @@ template + LoopScheduler LoopSched, + typename ComputeTypeA = FloatA, + typename ComputeTypeB = FloatB> constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() { if constexpr(LoopSched == LoopScheduler::Default) @@ -601,7 +612,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { @@ -615,7 +628,9 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() NPerXDL, MRepeat, NRepeat, - KPack>{}; + KPack, + ComputeTypeA, + ComputeTypeB>{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp index 86c025aa6f..7f28ec7680 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -60,7 +60,9 @@ template + LoopScheduler LoopSched = make_default_loop_scheduler(), + typename LDSTypeA = ComputeType, + typename LDSTypeB = ComputeType> struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK; + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; struct Argument : public GridwiseGemm::Argument { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 70c72bf768..33c2cb6c6d 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -21,50 +21,11 @@ struct PassThroughPack2 template __host__ __device__ void operator()(Y& y, const X& x) const; - __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::half2_t& x) const - { - // fake conversion - uint16_t t = ck::bit_cast(x); - y = ck::bit_cast(t); - } - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::f8x2_t& x) const { auto t = type_convert(x); y = type_convert(t); } - - __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::half2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::f8x2_t& y, const ck::f8x2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::float2_t& y, const ck::float2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::int8x2_t& y, const ck::int8x2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::bhalf2_t& y, const ck::bhalf2_t& x) const - { - y = x; - } - - __host__ __device__ constexpr void operator()(ck::double2_t& y, const ck::double2_t& x) const - { - y = x; - } - - constexpr const static bool is_pack2_invocable = true; }; struct PassThrough diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 6cbb834395..b52f5c51b1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -9,7 +9,6 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" @@ -96,7 +95,10 @@ template + typename ComputeTypeA = FloatC, + typename ComputeTypeB = ComputeTypeA, + typename LDSTypeA = ComputeTypeA, + typename LDSTypeB = ComputeTypeB> struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 { static constexpr auto I0 = Number<0>{}; @@ -430,7 +432,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto c_block_size = GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); - return math::max((a_block_space_size + b_block_space_size) * sizeof(ComputeType), + return math::max(a_block_space_size * sizeof(LDSTypeA) + + b_block_space_size * sizeof(LDSTypeB), c_block_size * sizeof(FloatC)); } @@ -785,7 +788,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, FloatA, - ComputeType, + LDSTypeA, decltype(a_b_k0_m_k1_grid_desc), decltype(a_b_k0_m_k1_block_desc), ABlockTransferSrcAccessOrder, @@ -815,7 +818,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, FloatB, - ComputeType, + LDSTypeB, decltype(b_b_k0_n_k1_grid_desc), decltype(b_b_k0_n_k1_block_desc), BBlockTransferSrcAccessOrder, @@ -845,8 +848,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, - ComputeType, // ComputeType A - ComputeType, // ComputeType B + LDSTypeA, + LDSTypeB, FloatAcc, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), @@ -855,7 +858,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 MRepeat, NRepeat, K1, - LoopSched>(); + LoopSched, + ComputeTypeA, + ComputeTypeB>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); @@ -863,8 +868,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 constexpr auto a_block_space_size = math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); - ComputeType* p_a_block = static_cast(p_shared_block); - ComputeType* p_b_block = static_cast(p_shared_block) + a_block_space_size; + auto p_a_block = reinterpret_cast(p_shared_block); + auto p_b_block = reinterpret_cast(p_a_block + a_block_space_size); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 2774214079..608679a4fa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -1156,27 +1158,56 @@ struct ThreadwiseTensorSliceTransfer_v4 src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); - // apply type convert src_tmp_vector.template AsType()(i) = src_buf[Number{}]; }); } - // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to - // DstData) - vector_type_maker_t dst_tmp_vector; - // TODO: if SrcData and DstData are vetor type, then static_cast may not compile - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - dst_tmp_vector.template AsType()(i) = - type_convert(src_tmp_vector.template AsType()[i]); - }); + if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; - // copy data from dst_tmp_vector into dst_buf - static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - constexpr index_t dst_offset = dst_desc.CalculateOffset( - dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + constexpr index_t pack_size = 2; - dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; - }); + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } }); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp index ebbe7c7211..863eddef24 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp @@ -189,6 +189,11 @@ void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances( DeviceGemmSplitK>>& instances); +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances( + std::vector>>& + instances); + void add_device_gemm_xdl_splitk_f16_f16_f16_comp_f8_km_kn_mn_instances( std::vector>>& @@ -352,6 +357,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v1_interwave_instances(op_ptrs); add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_v2_instances(op_ptrs); + add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt index a4d23914dd..059b6a720f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt @@ -1,42 +1,45 @@ set(GEMM_SPLITK_INSTANCES) -list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp - device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp) +list(APPEND GEMM_SPLITK_INSTANCES + device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_v2_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v1_interwave_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_v2_irregular_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_irregular_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v1_interwave_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_v2_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp + device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp + ) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp new file mode 100644 index 0000000000..0409dec369 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_kpb128_instance.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmKPadding = ck::tensor_operation::device::GemmSpecialization::KPadding; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances = std::tuple< + // clang-format off + //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 8, 16, 16, 16, 1, 1, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 8, 16, 16, 16, 1, 2, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8>, + DeviceGemmXdlSplitKCShuffle< F16, F8, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 8, 16, 16, 16, 1, 4, S<1, 8, 8, 2>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 8>, 4, F16, PipVer, LoopSche, F16, F8> + // clang-format on + >; + +void add_device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_kpb128_instances( + std::vector>>& + instances) +{ + // default + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmDefault, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // MNKPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // KPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmKPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); + + // MNPadding + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v2, + ck::LoopScheduler::Default>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Interwave>{}); + + add_device_operation_instances( + instances, + device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_irregular_kpb128_instances< + GemmMNPadding, + ck::PipelineVersion::v1, + ck::LoopScheduler::Default>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From a78be3f69e749e7a5e4d42f08415b0d13dbdb955 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:11:32 -0800 Subject: [PATCH 04/36] add docker credentials before pushing image (#1165) --- Jenkinsfile | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 071ac31439..becdc35b16 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -134,7 +134,9 @@ def buildDocker(install_prefix){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } sh 'docker images -q -f dangling=true | xargs --no-run-if-empty docker rmi' } else{ @@ -146,7 +148,9 @@ def buildDocker(install_prefix){ catch(Exception ex){ echo "Unable to locate image: ${image_name}. Building image now" retimage = docker.build("${image_name}", dockerArgs + ' .') - retimage.push() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.push() + } } } From bf98b4769714326bb8707f893c5f5a687b99825d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 13 Feb 2024 11:49:05 +0100 Subject: [PATCH 05/36] Add bilinear conv fwd and bwd data instances (#1164) --- .../CMakeLists.txt | 11 - .../24_grouped_conv_activation/CMakeLists.txt | 40 +++ ...d_conv_bwd_data_bilinear_residual_fp16.cpp | 217 ++++++++++++++ ...rouped_conv_fwd_bilinear_residual_fp16.cpp | 221 +++++++++++++++ .../grouped_conv_fwd_scaleadd_ab.inc | 2 +- .../grouped_conv_fwd_scaleadd_ab_bf16.cpp | 2 +- .../grouped_conv_fwd_scaleadd_ab_fp16.cpp | 2 +- .../grouped_conv_fwd_scaleadd_ab_fp32.cpp | 2 +- .../grouped_conv_fwd_scaleadd_ab_int8.cpp | 2 +- ...rouped_conv_fwd_scaleadd_scaleadd_relu.inc | 2 +- ...d_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp | 2 +- ...d_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp | 2 +- ...d_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp | 2 +- ...d_conv_fwd_scaleadd_scaleadd_relu_int8.cpp | 2 +- .../CMakeLists.txt | 11 - example/62_conv_fwd_activ/CMakeLists.txt | 49 ---- .../convnd_fwd_xdl_abs_fp16.cpp | 11 - .../convnd_fwd_xdl_clippedrelu_fp16.cpp | 11 - .../convnd_fwd_xdl_elu_fp16.cpp | 11 - .../convnd_fwd_xdl_leakyrelu_fp16.cpp | 11 - .../convnd_fwd_xdl_pow_fp16.cpp | 11 - .../convnd_fwd_xdl_relu_fp16.cpp | 11 - .../convnd_fwd_xdl_sigmoid_fp16.cpp | 11 - .../convnd_fwd_xdl_softrelu_fp16.cpp | 11 - .../convnd_fwd_xdl_tanh_fp16.cpp | 11 - example/62_convnd_activ/CMakeLists.txt | 17 ++ example/62_convnd_activ/binary/CMakeLists.txt | 13 + ...nd_bwd_data_xdl_bilinear_residual_fp16.cpp | 266 ++++++++++++++++++ .../convnd_fwd_xdl_bilinear_residual_fp16.cpp | 266 ++++++++++++++++++ ...aleadd_scaleadd_relu_bcasted_bias_fp16.cpp | 28 +- ...nd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp | 28 +- .../62_convnd_activ/multi_AB/CMakeLists.txt | 17 ++ .../conv_fwd_xdl_scaleadd_ab_bf16.cpp | 18 +- .../conv_fwd_xdl_scaleadd_ab_fp16.cpp | 18 +- .../conv_fwd_xdl_scaleadd_ab_fp32.cpp | 18 +- .../conv_fwd_xdl_scaleadd_ab_int8.cpp | 18 +- .../convnd_fwd_activ_multi_ab_common.hpp | 22 +- .../run_convnd_activ_example.inc} | 38 +-- example/62_convnd_activ/unary/CMakeLists.txt | 35 +++ .../unary/convnd_fwd_activ_unary_common.hpp} | 22 +- .../unary/convnd_fwd_xdl_abs_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_clippedrelu_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_elu_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_leakyrelu_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_pow_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_relu_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_sigmoid_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_softrelu_fp16.cpp | 11 + .../unary/convnd_fwd_xdl_tanh_fp16.cpp | 11 + .../element/binary_element_wise_operation.hpp | 13 +- .../cpu/reference_conv_bwd_data.hpp | 247 +++++++++++----- ...ed_conv_bwd_data_xdl_bilinear_instance.hpp | 132 +++++++++ ...grouped_conv_fwd_xdl_bilinear_instance.hpp | 131 +++++++++ ...ped_convolution_backward_data_bilinear.hpp | 150 ++++++++++ .../grouped_convolution_forward_bilinear.hpp | 177 ++++++++++++ .../CMakeLists.txt | 6 + ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 50 ++++ ...near_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 50 ++++ ...near_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 50 ++++ .../CMakeLists.txt | 7 + ...ear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 55 ++++ ...near_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 55 ++++ ...near_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 55 ++++ ...ear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 54 ++++ 64 files changed, 2471 insertions(+), 352 deletions(-) delete mode 100644 client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt create mode 100644 client_example/24_grouped_conv_activation/CMakeLists.txt create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp rename client_example/{24_grouped_convnd_fwd_scaleadd_ab => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab}/grouped_conv_fwd_scaleadd_ab.inc (99%) rename client_example/{24_grouped_convnd_fwd_scaleadd_ab => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab}/grouped_conv_fwd_scaleadd_ab_bf16.cpp (81%) rename client_example/{24_grouped_convnd_fwd_scaleadd_ab => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab}/grouped_conv_fwd_scaleadd_ab_fp16.cpp (81%) rename client_example/{24_grouped_convnd_fwd_scaleadd_ab => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab}/grouped_conv_fwd_scaleadd_ab_fp32.cpp (80%) rename client_example/{24_grouped_convnd_fwd_scaleadd_ab => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab}/grouped_conv_fwd_scaleadd_ab_int8.cpp (80%) rename client_example/{23_grouped_convnd_fwd_scaleadd_scaleadd_relu => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu}/grouped_conv_fwd_scaleadd_scaleadd_relu.inc (99%) rename client_example/{23_grouped_convnd_fwd_scaleadd_scaleadd_relu => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu}/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp (86%) rename client_example/{23_grouped_convnd_fwd_scaleadd_scaleadd_relu => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu}/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp (86%) rename client_example/{23_grouped_convnd_fwd_scaleadd_scaleadd_relu => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu}/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp (85%) rename client_example/{23_grouped_convnd_fwd_scaleadd_scaleadd_relu => 24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu}/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp (85%) delete mode 100644 client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt delete mode 100644 example/62_conv_fwd_activ/CMakeLists.txt delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp delete mode 100644 example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp create mode 100644 example/62_convnd_activ/CMakeLists.txt create mode 100644 example/62_convnd_activ/binary/CMakeLists.txt create mode 100644 example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp create mode 100644 example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp rename example/{62_conv_fwd_activ => 62_convnd_activ}/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp (93%) rename example/{62_conv_fwd_activ => 62_convnd_activ}/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp (92%) create mode 100644 example/62_convnd_activ/multi_AB/CMakeLists.txt rename example/{62_conv_fwd_activ => 62_convnd_activ}/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp (63%) rename example/{62_conv_fwd_activ => 62_convnd_activ}/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp (63%) rename example/{62_conv_fwd_activ => 62_convnd_activ}/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp (63%) rename example/{62_conv_fwd_activ => 62_convnd_activ}/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp (63%) rename example/{62_conv_fwd_activ => 62_convnd_activ}/multi_AB/convnd_fwd_activ_multi_ab_common.hpp (94%) rename example/{62_conv_fwd_activ/run_convnd_fwd_activ_example.inc => 62_convnd_activ/run_convnd_activ_example.inc} (78%) create mode 100644 example/62_convnd_activ/unary/CMakeLists.txt rename example/{62_conv_fwd_activ/convnd_fwd_activ_common.hpp => 62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp} (93%) create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp create mode 100644 example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt b/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt deleted file mode 100644 index 101a5b97ee..0000000000 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt new file mode 100644 index 0000000000..b4895db891 --- /dev/null +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -0,0 +1,40 @@ +# Fwd scaleadd scaleadd relu +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 + grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd scaleadd AB +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations) + +add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 + grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp) +target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations) +# Fwd bilinear +add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 + grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd data bilinear +add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 + grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_data_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000..bb106e8d8e --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -0,0 +1,217 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_bilinear() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple, + InDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {in.GetDeviceBuffer()}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {in_lengths}, + {in_strides}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_bilinear(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000..32ab481319 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -0,0 +1,221 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_bilinear() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple, + OutDataType, + PassThrough, + PassThrough, + Bilinear>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {out.GetDeviceBuffer()}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {out_lengths}, + {out_strides}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Bilinear{2.f, 2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_bilinear(); } diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc similarity index 99% rename from client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 923e279e7f..3f6f7b0773 100644 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp similarity index 81% rename from client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp index f384d854df..fef3f7428c 100644 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/data_type.hpp" #include "ck/utility/tuple.hpp" diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp similarity index 81% rename from client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp index fd61ef1e15..43db279191 100644 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/data_type.hpp" #include "ck/utility/tuple.hpp" diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp similarity index 80% rename from client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp index 387369c667..cccec47701 100644 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/data_type.hpp" #include "ck/utility/tuple.hpp" diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp similarity index 80% rename from client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp index 20654c7180..28674c8abe 100644 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/utility/data_type.hpp" #include "ck/utility/tuple.hpp" diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc similarity index 99% rename from client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc index e8f5529520..4e3cf69637 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp similarity index 86% rename from client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp index 559aaa8266..7a32c4f742 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp similarity index 86% rename from client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp index e1186fc81c..e3e91072b3 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp similarity index 85% rename from client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp index 02c6b3be55..e7ed96b6a0 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp similarity index 85% rename from client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp rename to client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp index dca2f3420b..9959664d2a 100644 --- a/client_example/23_grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_scaleadd_relu/grouped_conv_fwd_scaleadd_scaleadd_relu_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include diff --git a/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt b/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt deleted file mode 100644 index 38cd8b1791..0000000000 --- a/client_example/24_grouped_convnd_fwd_scaleadd_ab/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp32 grouped_conv_fwd_scaleadd_ab_fp32.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp32 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_ab_fp16 grouped_conv_fwd_scaleadd_ab_fp16.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_fp16 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_ab_bf16 grouped_conv_fwd_scaleadd_ab_bf16.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_bf16 PRIVATE composable_kernel::device_conv_operations) - -add_executable(client_grouped_convnd_fwd_scaleadd_ab_int8 grouped_conv_fwd_scaleadd_ab_int8.cpp) -target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composable_kernel::device_conv_operations) diff --git a/example/62_conv_fwd_activ/CMakeLists.txt b/example/62_conv_fwd_activ/CMakeLists.txt deleted file mode 100644 index d1f26bbfe1..0000000000 --- a/example/62_conv_fwd_activ/CMakeLists.txt +++ /dev/null @@ -1,49 +0,0 @@ -list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) -set(target 0) -foreach(gpu IN LISTS GPU_TARGETS) - if(gpu IN_LIST gpu_list AND target EQUAL 0) - add_custom_target(example_convnd_fwd_activ_xdl) - # Sigmoid - add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_sigmoid_fp16) - # Tanh - add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_tanh_fp16) - # Relu - add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_relu_fp16) - # SoftRelu - add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_softrelu_fp16) - # Abs - add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_abs_fp16) - # Pow - add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_pow_fp16) - # Clipped Relu - add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_clippedrelu_fp16) - # Leaky Relu - add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_leakyrelu_fp16) - # Elu - add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_elu_fp16) - # ScaleAdd on A and B - add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp16) - add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_fp32) - add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) - add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_conv_fwd_xdl_scaleadd_ab_int8) - # ScaleAdd ScaleAdd Relu - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) - add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) - add_example_dependencies(example_convnd_fwd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) - set(target 1) - endif() -endforeach() diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp deleted file mode 100644 index 4fe0c857fa..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_abs_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp deleted file mode 100644 index feabacc5c9..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_clippedrelu_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp deleted file mode 100644 index 793102dbc6..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_elu_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::Elu; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp deleted file mode 100644 index a77408db7e..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_leakyrelu_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp deleted file mode 100644 index 2b695cf8c3..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_pow_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::Power; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp deleted file mode 100644 index e1b6e3f0cc..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_relu_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::Relu; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp deleted file mode 100644 index 350c15a787..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_sigmoid_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::Sigmoid; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp deleted file mode 100644 index ec52e1a3c4..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_softrelu_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::SoftRelu; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp deleted file mode 100644 index dca405669a..0000000000 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_tanh_fp16.cpp +++ /dev/null @@ -1,11 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "convnd_fwd_activ_common.hpp" - -using OutElementOp = ck::tensor_operation::element_wise::TanH; - -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; -#include "run_convnd_fwd_activ_example.inc" - -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } diff --git a/example/62_convnd_activ/CMakeLists.txt b/example/62_convnd_activ/CMakeLists.txt new file mode 100644 index 0000000000..6eaddd3ff7 --- /dev/null +++ b/example/62_convnd_activ/CMakeLists.txt @@ -0,0 +1,17 @@ +add_subdirectory(binary) +add_subdirectory(multi_AB) +add_subdirectory(unary) + +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_xdl) + # ScaleAdd ScaleAdd Relu + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16) + add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp) + add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/binary/CMakeLists.txt b/example/62_convnd_activ/binary/CMakeLists.txt new file mode 100644 index 0000000000..7c07b6bca6 --- /dev/null +++ b/example/62_convnd_activ/binary/CMakeLists.txt @@ -0,0 +1,13 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_binary_xdl) + # Bilinear residual + add_example_executable(example_convnd_fwd_xdl_bilinear_residual_fp16 convnd_fwd_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_fwd_xdl_bilinear_residual_fp16) + add_example_executable(example_convnd_bwd_data_xdl_bilinear_residual_fp16 convnd_bwd_data_xdl_bilinear_residual_fp16.cpp) + add_example_dependencies(example_convnd_activ_binary_xdl example_convnd_bwd_data_xdl_bilinear_residual_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000..f5bddf2302 --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_bwd_data_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using InElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default; + +template +using DeviceGroupedConvNDBwdDataInstance = + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< + NDimSpatial, + OutLayout, + WeiLayout, + ck::Tuple, + InLayout, + OutDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + InDataType, + OutElementOp, + WeiElementOp, + InElementOp, + ConvSpec, // ConvForwardSpecialization + true, + true, + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 2, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + S<0, 2, 1>, // BBlockTransferSrcAccessOrder + 1, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 2, // BBlockTransferDstScalarPerVector_BK1 + 0, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDBwdDataInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor out(out_g_n_k_wos_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor in_host(in_g_n_c_wis_desc); + + std::cout << "out: " << out.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "in: " << in_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + in_host.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + out.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + in_host.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // Initialize based on out_host + Tensor in_device(in_host); + + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem in_device_buf(sizeof(InDataType) * in_device.mDesc.GetElementSpaceSize()); + + out_device_buf.ToDevice(out.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + in_device_buf.ToDevice(in_device.mData.data()); + + std::array a_g_n_k_wos_lengths{}; + std::array a_g_n_k_wos_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_c_wis_lengths{}; + std::array e_g_n_c_wis_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(out_g_n_k_wos_desc.GetLengths(), a_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), a_g_n_k_wos_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(in_g_n_c_wis_desc.GetLengths(), e_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), e_g_n_c_wis_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {in_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + out_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + in_device_buf.GetDeviceBuffer(), + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_c_wis_lengths}, + std::array, NumDs>{e_g_n_c_wis_strides}, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + out_element_op, + wei_element_op, + in_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetInputByte() / sizeof(InDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + std::array, NumDs> d_tensors = {in_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvBwdData(); + + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_host, + wei, + out, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + in_device_buf.FromDevice(in_device.mData.data()); + + return ck::utils::check_err(in_device.mData, in_host.mData); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp new file mode 100644 index 0000000000..ae1ebcb2cd --- /dev/null +++ b/example/62_convnd_activ/binary/convnd_fwd_xdl_bilinear_residual_fp16.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +constexpr ck::index_t NDimSpatial = 3; +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using OutDataType = ck::half_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + +using OutElementOp = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8>; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; + +namespace { +// Use custom implementation to pass two more tensors for post op +template +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + constexpr ck::index_t NumDs = 1; + Tensor in(in_g_n_c_wis_desc); + Tensor wei(wei_g_k_c_xs_desc); + Tensor out_host(out_g_n_k_wos_desc); + + std::cout << "in: " << in.mDesc << std::endl; + std::cout << "wei: " << wei.mDesc << std::endl; + std::cout << "out: " << out_host.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + in.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + wei.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + out_host.GenerateTensorValue(GeneratorTensor_3{-0.05, 0.05}); + } + + // Initialize based on out_host + Tensor out_device(out_host); + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out_device.mData.data()); + + std::array a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + std::array b_g_k_c_xs_strides{}; + std::array e_g_n_k_wos_lengths{}; + std::array e_g_n_k_wos_strides{}; + std::array conv_filter_strides{}; + std::array conv_filter_dilations{}; + std::array input_left_pads{}; + std::array input_right_pads{}; + + auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + // Use output as D + const std::array ds = {out_device_buf.GetDeviceBuffer()}; + + auto conv = DeviceConvNDFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument( + in_device_buf.GetDeviceBuffer(), + wei_device_buf.GetDeviceBuffer(), + ds, + out_device_buf.GetDeviceBuffer(), + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + std::array, NumDs>{e_g_n_k_wos_lengths}, + std::array, NumDs>{e_g_n_k_wos_strides}, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error("The device op with the specified compilation parameters does " + "not support this convolution problem."); + } + + float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = + conv_param.GetFlops() + 3 * conv_param.GetOutputByte() / sizeof(OutDataType); + std::size_t num_btype = conv_param.GetByte() + + conv_param.GetOutputByte(); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv.GetTypeString() << std::endl; + + if(do_verification) + { + std::array, NumDs> d_tensors = {out_host}; + auto ref_conv = + ck::tensor_operation::host::ReferenceConvFwd(); + + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in, + wei, + out_host, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_, + in_element_op, + wei_element_op, + out_element_op, + {}, + {}, + d_tensors); + + ref_invoker.Run(ref_argument); + + out_device_buf.FromDevice(out_device.mData.data()); + + return ck::utils::check_err(out_device, out_host, "Error: incorrect results!"); + } + + return true; +} + +} // namespace + +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp similarity index 93% rename from example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp rename to example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp index 196636f8b5..d101fd59bd 100644 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -97,7 +97,7 @@ using DeviceGroupedConvNDFwdInstance = S<1, 32, 1, 8>, 8>; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; namespace { // Use custom implementation to pass two more tensors for post op @@ -109,16 +109,16 @@ template -bool run_grouped_conv_fwd(bool do_verification, - int init_method, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param, - const HostTensorDescriptor& in_g_n_c_wis_desc, - const HostTensorDescriptor& wei_g_k_c_xs_desc, - const HostTensorDescriptor& out_g_n_k_wos_desc, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) { constexpr ck::index_t NumDs = 2; const ck::index_t G = out_g_n_k_wos_desc.GetLengths()[0]; @@ -289,6 +289,6 @@ bool run_grouped_conv_fwd(bool do_verification, } // namespace -#include "run_convnd_fwd_activ_example.inc" +#include "run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp similarity index 92% rename from example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp rename to example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp index 572c4bb7a5..f784655cc5 100644 --- a/example/62_conv_fwd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp +++ b/example/62_convnd_activ/convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -94,7 +94,7 @@ using DeviceGroupedConvNDFwdInstance = S<1, 32, 1, 8>, 8>; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; namespace { // Use custom implementation to pass two more tensors for post op @@ -106,16 +106,16 @@ template -bool run_grouped_conv_fwd(bool do_verification, - int init_method, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param, - const HostTensorDescriptor& in_g_n_c_wis_desc, - const HostTensorDescriptor& wei_g_k_c_xs_desc, - const HostTensorDescriptor& out_g_n_k_wos_desc, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) { constexpr ck::index_t NumDs = 2; Tensor in(in_g_n_c_wis_desc); @@ -265,6 +265,6 @@ bool run_grouped_conv_fwd(bool do_verification, } // namespace -#include "run_convnd_fwd_activ_example.inc" +#include "run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/multi_AB/CMakeLists.txt b/example/62_convnd_activ/multi_AB/CMakeLists.txt new file mode 100644 index 0000000000..c89c82d384 --- /dev/null +++ b/example/62_convnd_activ/multi_AB/CMakeLists.txt @@ -0,0 +1,17 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_multi_ab_xdl) + # ScaleAdd on A and B + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp16 conv_fwd_xdl_scaleadd_ab_fp16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_fp32 conv_fwd_xdl_scaleadd_ab_fp32.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_fp32) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_bf16 conv_fwd_xdl_scaleadd_ab_bf16.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_bf16) + add_example_executable(example_conv_fwd_xdl_scaleadd_ab_int8 conv_fwd_xdl_scaleadd_ab_int8.cpp) + add_example_dependencies(example_convnd_activ_multi_ab_xdl example_conv_fwd_xdl_scaleadd_ab_int8) + set(target 1) + endif() +endforeach() diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp similarity index 63% rename from example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp rename to example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp index 7993552210..b7ceee03b8 100644 --- a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_bf16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_activ_multi_ab_common.hpp" @@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; -#include "../run_convnd_fwd_activ_example.inc" +#include "../run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp similarity index 63% rename from example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp rename to example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp index 696bc0c3fe..08d8a89669 100644 --- a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_activ_multi_ab_common.hpp" @@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; -#include "../run_convnd_fwd_activ_example.inc" +#include "../run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp similarity index 63% rename from example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp rename to example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp index a95f5e1347..bef9980b3e 100644 --- a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_activ_multi_ab_common.hpp" @@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; -#include "../run_convnd_fwd_activ_example.inc" +#include "../run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp similarity index 63% rename from example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp rename to example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp index 4fde3a722d..2b132b9121 100644 --- a/example/62_conv_fwd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp +++ b/example/62_convnd_activ/multi_AB/conv_fwd_xdl_scaleadd_ab_int8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include "convnd_fwd_activ_multi_ab_common.hpp" @@ -14,13 +14,13 @@ using BDataTypes = ck::Tuple; using InElementOp = ck::tensor_operation::element_wise::ScaleAdd; using WeiElementOp = ck::tensor_operation::element_wise::ScaleAdd; -using DeviceGroupedConvNDFwdActivInstance = DeviceGroupedConvNDMultiABFwdInstance; +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDMultiABFwdInstance; -#include "../run_convnd_fwd_activ_example.inc" +#include "../run_convnd_activ_example.inc" -int main(int argc, char* argv[]) { return !run_convnd_fwd_example(argc, argv); } +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp similarity index 94% rename from example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp rename to example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp index f61a91748f..2626843ed4 100644 --- a/example/62_conv_fwd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp +++ b/example/62_convnd_activ/multi_AB/convnd_fwd_activ_multi_ab_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -100,16 +100,16 @@ template -bool run_grouped_conv_fwd(bool do_verification, - int init_method, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param, - const HostTensorDescriptor& in_g_n_c_wis_desc, - const HostTensorDescriptor& wei_g_k_c_xs_desc, - const HostTensorDescriptor& out_g_n_k_wos_desc, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) { constexpr ck::index_t NumAs = 2; constexpr ck::index_t NumBs = 2; diff --git a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc b/example/62_convnd_activ/run_convnd_activ_example.inc similarity index 78% rename from example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc rename to example/62_convnd_activ/run_convnd_activ_example.inc index aa547c870a..5a402e41cd 100644 --- a/example/62_conv_fwd_activ/run_convnd_fwd_activ_example.inc +++ b/example/62_convnd_activ/run_convnd_activ_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -11,7 +11,7 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } -bool run_convnd_fwd_example(int argc, char* argv[]) +bool run_convnd_example(int argc, char* argv[]) { print_helper_msg(); @@ -63,23 +63,23 @@ bool run_convnd_fwd_example(int argc, char* argv[]) ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( conv_param); - return run_grouped_conv_fwd(do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + return run_grouped_conv(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); }; if(conv_param.num_dim_spatial_ == 3) diff --git a/example/62_convnd_activ/unary/CMakeLists.txt b/example/62_convnd_activ/unary/CMakeLists.txt new file mode 100644 index 0000000000..94ffb3661c --- /dev/null +++ b/example/62_convnd_activ/unary/CMakeLists.txt @@ -0,0 +1,35 @@ +list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list AND target EQUAL 0) + add_custom_target(example_convnd_activ_unary_xdl) + # Sigmoid + add_example_executable(example_convnd_fwd_xdl_sigmoid_fp16 convnd_fwd_xdl_sigmoid_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_sigmoid_fp16) + # Tanh + add_example_executable(example_convnd_fwd_xdl_tanh_fp16 convnd_fwd_xdl_tanh_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_tanh_fp16) + # Relu + add_example_executable(example_convnd_fwd_xdl_relu_fp16 convnd_fwd_xdl_relu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_relu_fp16) + # SoftRelu + add_example_executable(example_convnd_fwd_xdl_softrelu_fp16 convnd_fwd_xdl_softrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_softrelu_fp16) + # Abs + add_example_executable(example_convnd_fwd_xdl_abs_fp16 convnd_fwd_xdl_abs_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_abs_fp16) + # Pow + add_example_executable(example_convnd_fwd_xdl_pow_fp16 convnd_fwd_xdl_pow_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_pow_fp16) + # Clipped Relu + add_example_executable(example_convnd_fwd_xdl_clippedrelu_fp16 convnd_fwd_xdl_clippedrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_clippedrelu_fp16) + # Leaky Relu + add_example_executable(example_convnd_fwd_xdl_leakyrelu_fp16 convnd_fwd_xdl_leakyrelu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_leakyrelu_fp16) + # Elu + add_example_executable(example_convnd_fwd_xdl_elu_fp16 convnd_fwd_xdl_elu_fp16.cpp) + add_example_dependencies(example_convnd_activ_unary_xdl example_convnd_fwd_xdl_elu_fp16) + set(target 1) + endif() +endforeach() diff --git a/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp similarity index 93% rename from example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp rename to example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp index dbeaa426c5..4669465bf4 100644 --- a/example/62_conv_fwd_activ/convnd_fwd_activ_common.hpp +++ b/example/62_convnd_activ/unary/convnd_fwd_activ_unary_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -102,16 +102,16 @@ template -bool run_grouped_conv_fwd(bool do_verification, - int init_method, - bool time_kernel, - const ck::utils::conv::ConvParam& conv_param, - const HostTensorDescriptor& in_g_n_c_wis_desc, - const HostTensorDescriptor& wei_g_k_c_xs_desc, - const HostTensorDescriptor& out_g_n_k_wos_desc, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) +bool run_grouped_conv(bool do_verification, + int init_method, + bool time_kernel, + const ck::utils::conv::ConvParam& conv_param, + const HostTensorDescriptor& in_g_n_c_wis_desc, + const HostTensorDescriptor& wei_g_k_c_xs_desc, + const HostTensorDescriptor& out_g_n_k_wos_desc, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) { Tensor in(in_g_n_c_wis_desc); Tensor wei(wei_g_k_c_xs_desc); diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp new file mode 100644 index 0000000000..e621c3b15e --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_abs_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::UnaryAbs; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp new file mode 100644 index 0000000000..bb26fb292c --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_clippedrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::ClippedRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp new file mode 100644 index 0000000000..1aa4d5d4fa --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_elu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Elu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp new file mode 100644 index 0000000000..659c36ec8d --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_leakyrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::LeakyRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp new file mode 100644 index 0000000000..5efa0f8f9c --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_pow_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Power; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp new file mode 100644 index 0000000000..84b7c598ee --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_relu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Relu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp new file mode 100644 index 0000000000..53e06d387f --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_sigmoid_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::Sigmoid; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp new file mode 100644 index 0000000000..a2d76da4e2 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_softrelu_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::SoftRelu; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp new file mode 100644 index 0000000000..d60c005e09 --- /dev/null +++ b/example/62_convnd_activ/unary/convnd_fwd_xdl_tanh_fp16.cpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_activ_unary_common.hpp" + +using OutElementOp = ck::tensor_operation::element_wise::TanH; + +using DeviceGroupedConvNDActivInstance = DeviceGroupedConvNDFwdInstance; +#include "../run_convnd_activ_example.inc" + +int main(int argc, char* argv[]) { return !run_convnd_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 95048469cd..ba2e0057d9 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -165,7 +165,7 @@ struct Subtract struct Bilinear { - Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; template __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; @@ -184,6 +184,14 @@ struct Bilinear y = alpha_ * x0 + beta_ * x1; }; + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + template <> __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const @@ -221,7 +229,8 @@ struct Bilinear __host__ __device__ constexpr void operator()( std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const { - y = type_convert(x0 + ck::type_convert(x1)); + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); }; float alpha_; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp index bfb8b48187..a41f952408 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -25,25 +25,35 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvBwdData : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { - Argument(Tensor& input, - const Tensor& weight, - const Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + Argument( + Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const std::array, NumAElementwiseTensor>& elementwise_a_tensors, + const std::array, NumBElementwiseTensor>& elementwise_b_tensors, + const std::array, NumDElementwiseTensor>& elementwise_d_tensors) : input_{input}, weight_{weight}, output_{output}, + elementwise_a_tensors_{elementwise_a_tensors}, + elementwise_b_tensors_{elementwise_b_tensors}, + elementwise_d_tensors_{elementwise_d_tensors}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, in_left_pads_{input_left_pads}, @@ -58,6 +68,10 @@ struct ReferenceConvBwdData : public device::BaseOperator const Tensor& weight_; const Tensor& output_; + const std::array, NumAElementwiseTensor>& elementwise_a_tensors_; + const std::array, NumBElementwiseTensor>& elementwise_b_tensors_; + const std::array, NumDElementwiseTensor>& elementwise_d_tensors_; + std::vector conv_strides_; std::vector conv_dilations_; std::vector in_left_pads_; @@ -106,26 +120,46 @@ struct ReferenceConvBwdData : public device::BaseOperator { for(std::size_t k = 0; k < K; ++k) { - float v_out = 0; - float v_wei = 0; + OutDataType v_out; + WeiDataType v_wei; - arg.out_element_op_( - v_out, ck::type_convert(arg.output_(g, n, k, wo))); + ExecuteElementwiseOp(arg.out_element_op_, + arg.elementwise_a_tensors_, + Number{}, + v_out, + arg.output_(g, n, k, wo), + g, + n, + k, + wo); + ExecuteElementwiseOp(arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, + v_wei, + arg.weight_(g, k, c, x), + g, + k, + c, + x); - arg.wei_element_op_( - v_wei, ck::type_convert(arg.weight_(g, k, c, x))); - - v_acc += v_out * v_wei; + v_acc += ck::type_convert(v_out) * + ck::type_convert(v_wei); } } } } - float v_in; - - arg.in_element_op_(v_in, v_acc); - - arg.input_(g, n, c, wi) = ck::type_convert(v_in); + InDataType v_acc_converted = ck::type_convert(v_acc); + InDataType& v_in = arg.input_(g, n, c, wi); + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_in, + v_acc_converted, + g, + n, + c, + wi); }; make_ParallelTensorFunctor(f_ncw, @@ -175,20 +209,34 @@ struct ReferenceConvBwdData : public device::BaseOperator { for(std::size_t k = 0; k < K; ++k) { - float v_out = 0; - float v_wei = 0; + OutDataType v_out; + WeiDataType v_wei; - arg.out_element_op_( + ExecuteElementwiseOp( + arg.out_element_op_, + arg.elementwise_a_tensors_, + Number{}, v_out, - ck::type_convert( - arg.output_(g, n, k, ho, wo))); - - arg.wei_element_op_( + arg.output_(g, n, k, ho, wo), + g, + n, + k, + ho, + wo); + ExecuteElementwiseOp( + arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, v_wei, - ck::type_convert( - arg.weight_(g, k, c, y, x))); + arg.weight_(g, k, c, y, x), + g, + k, + c, + y, + x); - v_acc += v_out * v_wei; + v_acc += ck::type_convert(v_out) * + ck::type_convert(v_wei); } } } @@ -197,11 +245,18 @@ struct ReferenceConvBwdData : public device::BaseOperator } } - float v_in; - - arg.in_element_op_(v_in, v_acc); - - arg.input_(g, n, c, hi, wi) = ck::type_convert(v_in); + InDataType v_acc_converted = ck::type_convert(v_acc); + InDataType& v_in = arg.input_(g, n, c, hi, wi); + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_in, + v_acc_converted, + g, + n, + c, + hi, + wi); }; make_ParallelTensorFunctor(f_nchw, @@ -270,20 +325,37 @@ struct ReferenceConvBwdData : public device::BaseOperator { for(std::size_t k = 0; k < K; ++k) { - float v_out = 0; - float v_wei = 0; + OutDataType v_out; + WeiDataType v_wei; - arg.out_element_op_( + ExecuteElementwiseOp( + arg.out_element_op_, + arg.elementwise_a_tensors_, + Number{}, v_out, - ck::type_convert(arg.output_( - g, n, k, do_, ho, wo))); - - arg.wei_element_op_( + arg.output_(g, n, k, do_, ho, wo), + g, + n, + k, + do_, + ho, + wo); + ExecuteElementwiseOp( + arg.wei_element_op_, + arg.elementwise_b_tensors_, + Number{}, v_wei, - ck::type_convert( - arg.weight_(g, k, c, z, y, x))); + arg.weight_(g, k, c, z, y, x), + g, + k, + c, + z, + y, + x); - v_acc += v_out * v_wei; + v_acc += + ck::type_convert(v_out) * + ck::type_convert(v_wei); } } } @@ -295,11 +367,19 @@ struct ReferenceConvBwdData : public device::BaseOperator } } - float v_in; - - arg.in_element_op_(v_in, v_acc); - - arg.input_(g, n, c, di, hi, wi) = ck::type_convert(v_in); + InDataType v_acc_converted = ck::type_convert(v_acc); + InDataType& v_in = arg.input_(g, n, c, di, hi, wi); + ExecuteElementwiseOp(arg.in_element_op_, + arg.elementwise_d_tensors_, + Number{}, + v_in, + v_acc_converted, + g, + n, + c, + di, + hi, + wi); }; make_ParallelTensorFunctor(f_ncdhw, @@ -325,6 +405,36 @@ struct ReferenceConvBwdData : public device::BaseOperator } }; + template + static void ExecuteElementwiseOp(ElementwiseOp& elementwise_op, + ElementwiseTensor& elementwise_tensors, + NumTensor, + T& y, + const T& x, + Args... dims) + { + if constexpr(NumTensor::value == 0) + { + elementwise_op(y, x); + } + else if constexpr(NumTensor::value == 1) + { + elementwise_op(y, x, elementwise_tensors[0](dims...)); + } + else if constexpr(NumTensor::value == 2) + { + elementwise_op(y, x, elementwise_tensors[0](dims...), elementwise_tensors[1](dims...)); + } + else + { + throw std::runtime_error("ElementOp not supported in reference."); + } + } + static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check @@ -333,16 +443,20 @@ struct ReferenceConvBwdData : public device::BaseOperator bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - static auto MakeArgument(Tensor& input, - const Tensor& weight, - const Tensor& output, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) + static auto MakeArgument( + Tensor& input, + const Tensor& weight, + const Tensor& output, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const std::array, NumAElementwiseTensor>& elementwise_a_tensors = {}, + const std::array, NumBElementwiseTensor>& elementwise_b_tensors = {}, + const std::array, NumDElementwiseTensor>& elementwise_d_tensors = {}) { return Argument{input, weight, @@ -353,7 +467,10 @@ struct ReferenceConvBwdData : public device::BaseOperator input_right_pads, in_element_op, wei_element_op, - out_element_op}; + out_element_op, + elementwise_a_tensors, + elementwise_b_tensors, + elementwise_d_tensors}; } static auto MakeInvoker() { return Invoker{}; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp new file mode 100644 index 0000000000..93a1ef2096 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// f16_f16_f32_f16 +template +using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +// f32_f32_f32_f32 +template +using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> + // clang-format on + >; + +// f16_f16_f16_comp_f8 +template +using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp new file mode 100644 index 0000000000..3c689990aa --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp new file mode 100644 index 0000000000..595288e193 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_bilinear.hpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD< + NumDimSpatial, + OutLayout, + WeiLayout, + Tuple, + InLayout, + OutDataType, + WeiDataType, + Tuple, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear, + ComputeTypeA, + ComputeTypeB>> +{ + using DeviceOp = + DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + Tuple, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Bilinear, + ComputeTypeA, + ComputeTypeB>; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp new file mode 100644 index 0000000000..c8375da6e1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bilinear.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Bilinear = ck::tensor_operation::element_wise::Bilinear; + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple, + int8_t, + PassThrough, + PassThrough, + Bilinear>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v && + DLayouts::Size() == 1 && is_same_v, NDHWGK>) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances( + op_ptrs); + } +#endif + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt new file mode 100644 index 0000000000..e1cb975291 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/CMakeLists.txt @@ -0,0 +1,6 @@ +set(GROUPED_CONV3D_BWD_DATA_BILINEAR + xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp) + +add_instance_library(device_grouped_conv3d_bwd_data_bilinear_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..c25c481c05 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances(instances, + device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances< + 3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..f61083e791 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..2e014ae760 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_bilinear/xdl/device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_bilinear_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt new file mode 100644 index 0000000000..49706588d6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/CMakeLists.txt @@ -0,0 +1,7 @@ +set(GROUPED_CONV3D_FWD_BILINEAR + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_bilinear_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..4f5461d12b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple, + BF16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..e3a4de83f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple, + F16, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..fc3ee53570 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple, + F32, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp new file mode 100644 index 0000000000..eccdcff845 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bilinear/xdl/device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple, + int8_t, + PassThrough, + PassThrough, + Bilinear>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_bilinear_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 1e73adbc2809fb582c40f91daa8ecd7cd6737aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 13 Feb 2024 17:04:36 +0100 Subject: [PATCH 06/36] Add optimized blockwise gemm using ck wrapper (#1157) * Add optimized blockwise gemm using ck wrapper * Add basic gemm example * Update docs * Add tutorial for gemm using ck wrapper * Add perf note * edits * Fix cmake * Fixes --------- Co-authored-by: Lisa Delaney --- client_example/25_wrapper/CMakeLists.txt | 8 + client_example/25_wrapper/README.md | 177 +++++++++ .../25_wrapper/wrapper_basic_gemm.cpp | 216 ++++++++++ client_example/25_wrapper/wrapper_img2col.cpp | 42 +- .../25_wrapper/wrapper_optimized_gemm.cpp | 308 ++++++++++++++ docs/wrapper.rst | 10 +- include/ck/wrapper/operations/copy.hpp | 68 ++-- include/ck/wrapper/operations/gemm.hpp | 98 +++-- include/ck/wrapper/tensor.hpp | 4 +- .../traits/blockwise_gemm_xdl_traits.hpp | 47 ++- include/ck/wrapper/utils/kernel_utils.hpp | 14 + include/ck/wrapper/utils/layout_utils.hpp | 105 ++++- include/ck/wrapper/utils/tensor_partition.hpp | 290 +++++++++----- test/wrapper/CMakeLists.txt | 27 +- test/wrapper/test_gemm.cpp | 257 ------------ .../{test_copy.cpp => test_wrapper_copy.cpp} | 27 +- test/wrapper/test_wrapper_gemm.cpp | 376 ++++++++++++++++++ ...est_layout.cpp => test_wrapper_layout.cpp} | 2 +- ...rtition.cpp => test_wrapper_partition.cpp} | 33 +- ...est_tensor.cpp => test_wrapper_tensor.cpp} | 0 20 files changed, 1597 insertions(+), 512 deletions(-) create mode 100644 client_example/25_wrapper/README.md create mode 100644 client_example/25_wrapper/wrapper_basic_gemm.cpp create mode 100644 client_example/25_wrapper/wrapper_optimized_gemm.cpp create mode 100644 include/ck/wrapper/utils/kernel_utils.hpp delete mode 100644 test/wrapper/test_gemm.cpp rename test/wrapper/{test_copy.cpp => test_wrapper_copy.cpp} (83%) create mode 100644 test/wrapper/test_wrapper_gemm.cpp rename test/wrapper/{test_layout.cpp => test_wrapper_layout.cpp} (99%) rename test/wrapper/{test_partition.cpp => test_wrapper_partition.cpp} (79%) rename test/wrapper/{test_tensor.cpp => test_wrapper_tensor.cpp} (100%) diff --git a/client_example/25_wrapper/CMakeLists.txt b/client_example/25_wrapper/CMakeLists.txt index eb3be0e6c8..fdfc1d8d2e 100644 --- a/client_example/25_wrapper/CMakeLists.txt +++ b/client_example/25_wrapper/CMakeLists.txt @@ -2,3 +2,11 @@ add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrap target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations) add_executable(client_wrapper_img2col wrapper_img2col.cpp) target_link_libraries(client_wrapper_img2col PRIVATE composable_kernel::device_other_operations) +if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR + GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR + GPU_TARGETS MATCHES "gfx942") + add_executable(client_wrapper_basic_gemm wrapper_basic_gemm.cpp) + target_link_libraries(client_wrapper_basic_gemm PRIVATE composable_kernel::device_other_operations) + add_executable(client_wrapper_optimized_gemm wrapper_optimized_gemm.cpp) + target_link_libraries(client_wrapper_optimized_gemm PRIVATE composable_kernel::device_other_operations) +endif() diff --git a/client_example/25_wrapper/README.md b/client_example/25_wrapper/README.md new file mode 100644 index 0000000000..eba3de017f --- /dev/null +++ b/client_example/25_wrapper/README.md @@ -0,0 +1,177 @@ +# Composable Kernel wrapper GEMM tutorial + +This tutorial demonstrates how to implement matrix multiplication using Composable Kernel (CK) +wrapper. We present the base version of GEMM without most of the available optimizations; however, +it's worth noting that CK has kernels with different optimizations. + +To implement these optimizations, you can use the CK wrapper or directly use available instances in +CK. You can also refer to the +[optimized GEMM example](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_optimized_gemm.cpp), +that uses CK wrapper based on the +[`gridwise_gemm_xdlops_v2r3`](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp) implementation. + +The kernel definition should look similar to: + +```cpp +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +``` + +We pass pointers to global memory and matrix dimensions via arguments. Additionally, we pass +selected lengths of processed data through each block (`tile_shape`) and thread layout +(`thread_layout`). For compilation time parameters, we define the data type, +[traits for the GEMM operation](https://github.com/ROCm/composable_kernel/blob/develop/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp) +and scalar per vector value during copy. + +Step 1: Create layouts for global and LDS memory. + +```cpp + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); +``` + +We pad layouts for global tensors in case M, N, and K are not divisible by `MPerBlock`, `NPerBlock`, or +`KPerBlock`. + +Step 2: Create tensors for global and LDS memory. + +```cpp + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + + // Allocate LDS memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); +``` + +We must specify parameters for copy and convert block indexes to tuple: + +```cpp + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; +``` + +We create a local tile (per block) and local partitions (per thread) for the global memory (`C`). We also +define and clear an output register (`c_vgpr_reg`) for the accumulation. + +```cpp + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); +``` + +We use two specific functions for `blockwise_gemm`: `make_blockwise_gemm_xdl_c_local_partition` and +`make_blockwise_gemm_xdl_c_vgpr`. This helps to choose the appropriate partition for the `C` output +and define tensors with specific layouts for `blockwise_gemm`. In the following step, we use only +generic functions for the CK wrapper. + +Step 3: Create the compute loop. + +```cpp + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to LDS. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise GEMM. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); +``` + +Loop iterate over `K / KPerBlock`. Each time a local tile is created for A and B tensors (tensor per block), +data is copied from global memory to LDS. The `blockwise_gemm` function performs the GEMM +operation on `a_lds_tensor` and `b_lds_tensor`, and stores results in `c_vgpr_reg`. + +The end result from `c_vgpr_reg` is stored in the `C` local partition (tensor per thread): + +```cpp + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +``` + +If you want to dive deep into the details, you can find the entire example +[here](https://github.com/ROCm/composable_kernel/blob/develop/client_example/25_wrapper/wrapper_basic_gemm.cpp). diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp new file mode 100644 index 0000000000..1f1a4de751 --- /dev/null +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + + // Specify layouts for global memory. + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Specify layouts for tiles. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); + constexpr auto c_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); + // Apply padding for global memory. + auto a_global_layout_padded = ck::wrapper::pad(a_global_layout, shape(a_tile_layout)); + auto b_global_layout_padded = ck::wrapper::pad(b_global_layout, shape(b_tile_layout)); + auto c_global_layout_padded = ck::wrapper::pad(c_global_layout, shape(c_tile_layout)); + // Make tensors for global memory. + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_global_layout_padded); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_global_layout_padded); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_global_layout_padded); + // Allocate lds memory. + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; + // Make tensors for lds memory. + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + // Specify block index as tuple. + const auto block_idxs = ck::make_tuple(static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + // Specify access parameters for copy. + using DimAccessOrder = ck::Tuple, ck::Number<1>>; + constexpr ck::index_t vector_dim = 1; + // Create tile and partition for C. Use specific function for blockwise_gemm to assign the + // appropriate partitions. + auto c_global_local_tile = ck::wrapper::make_local_tile( + c_global_tensor, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Create C vgpr to accumulate results. + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + // Clear C vgpr. + ck::wrapper::clear(c_vgpr_reg); + + // Iterate over K with KPerBlock step. + const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); + ck::index_t i = 0; + do + { + // Get KPerBlock slice. + const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); + auto a_global_tensor_k_slice = a_global_tensor(ck::wrapper::slice(), k_slice); + auto b_global_tensor_k_slice = b_global_tensor(ck::wrapper::slice(), k_slice); + // Create local tiles for A and B. + auto a_global_local_tile = ck::wrapper::make_local_tile( + a_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); + auto b_global_local_tile = ck::wrapper::make_local_tile( + b_global_tensor_k_slice, + tile_shape, + block_idxs, + make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); + // Copy from global to lds. + ck::wrapper::blockwise_copy( + a_global_local_tile, a_lds_tensor, thread_layout); + ck::wrapper::blockwise_copy( + b_global_local_tile, b_lds_tensor, thread_layout); + // Synchronize lds. + ck::block_sync_lds(); + // Execute blockwise gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ++i; + } while(i < num_loop); + // Copy vgpr results to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}), + ck::make_tuple(ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index 35074be4c1..2a4034d62f 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -15,6 +15,7 @@ #include "ck/wrapper/layout.hpp" #include "ck/wrapper/tensor.hpp" #include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" static constexpr ck::index_t NumDimSpatial = 3; using DataType = float; @@ -36,21 +37,20 @@ struct SimpleDeviceMem void* p_mem_; }; -// Test copy from Global to Global through LDS and VGPR -template -__global__ void DeviceImageToColumnPad0(InputTensor input_tensor, - OutputTensor output_tensor, - const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ +DeviceImageToColumnPad0(InputTensor input_tensor, + OutputTensor output_tensor, + const BlockShape tile_shape, + const ThreadLayout thread_layout) { - const ck::index_t block_idx = static_cast(blockIdx.x); + // grid layout (dim1, dim0) + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.y), static_cast(blockIdx.x)); // Get local tiles for global memory - auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); - auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); + auto output_local_tile = ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); // Get partition per thread const auto input_local_partition = @@ -112,9 +112,11 @@ void PerformImageToColumnPad0(const ck::index_t G, SimpleDeviceMem out_buf(ck::wrapper::size(out_layout) * sizeof(DataType)); // User can choose appropriate number of threads and sizes per block - const auto thread_layout = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}), + ck::make_tuple(ck::Number<16>{}, ck::Number<1>{})); // This example doesn't support padding, user should select tile sizes - // which divides the shape completely + // which are divisible by the shape. const auto tile_shape = ck::make_tuple(ck::Number<32>{}, ck::Number<64>{}); // Create buffers for global memory @@ -123,10 +125,11 @@ void PerformImageToColumnPad0(const ck::index_t G, auto output_tensor_global = ck::wrapper::make_tensor( static_cast(out_buf.GetDeviceBuffer()), out_layout); - const ck::index_t grid_size = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), - ck::wrapper::size<0>(tile_shape)) * - ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), - ck::wrapper::size<1>(tile_shape)); + // grid layout (dim1, dim0) + const ck::index_t grid_size_x = ck::math::integer_divide_ceil(ck::wrapper::size<1>(in_layout), + ck::wrapper::size<1>(tile_shape)); + const ck::index_t grid_size_y = ck::math::integer_divide_ceil(ck::wrapper::size<0>(in_layout), + ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0; const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, - dim3(grid_size), + dim3(grid_size_x, grid_size_y, 1), dim3(ck::wrapper::size(thread_layout)), 0, input_tensor_global, @@ -178,3 +181,4 @@ int main(int argc, char* argv[]) {1, 1, 1} /*filter_dilations*/); return 0; } +// MI100 Perf: 0.255178 ms, 1698.9 GB/s, diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp new file mode 100644 index 0000000000..ddf01de612 --- /dev/null +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -0,0 +1,308 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + // Create layouts for global memory + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + // Apply padding + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + // Create tensors for global memory + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + // Create layouts and tensors for lds memory. + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + K0PerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + K0PerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + // Create tile and partition for C global memory. Use specific gemm + // functions to get appropriate layouts. + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + // Define and clear c vgpr register + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + // Local partitions for lds memory + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + // Lamda to slice tensor, then create local tile and partition + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + // Copy first values to lds + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + // Pipeline loop + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + // Skip if only tile should be processed + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + // Copy data to A vgpr. + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + // Synchronize. + ck::block_sync_lds(); + // Copy data to B vgpr. + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + // Perform gemm. + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Synchronize + ck::block_sync_lds(); + // Copy data to A and B lds tiles. + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + // Handle tail. + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + // Store data from C vgpr to C global memory. + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + SimpleDeviceMem a_mem(M * K * sizeof(DataType)); + SimpleDeviceMem b_mem(K * N * sizeof(DataType)); + SimpleDeviceMem c_mem(M * N * sizeof(DataType)); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; +} + +int main(int argc, char* argv[]) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 3840, 4096, 4096, tile_shape, thread_layout); + return 0; +} +// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s, diff --git a/docs/wrapper.rst b/docs/wrapper.rst index c64c0bf17f..39e2fd0bbd 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -12,10 +12,6 @@ Wrapper Description ------------------------------------- -.. note:: - - The wrapper is under development and its functionality is limited. - The CK library provides a lightweight wrapper for more complex operations implemented in the library. @@ -54,9 +50,15 @@ Output:: 2 6 10 14 18 22 26 30 +Tutorials: + +* `GEMM tutorial `_ + Advanced examples: * `Image to column `_ +* `Basic gemm `_ +* `Optimized gemm `_ ------------------------------------- Layout diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index 614dfd758e..5f64031ebe 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -61,12 +61,12 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) decltype(dim_access_order), VectorDim, ScalarPerVector, - Sequence, - Sequence>{in_grid_desc, - make_tuple(src_tensor.GetMultiIdxOffsets()), - out_grid_desc, - make_tuple(dst_tensor.GetMultiIdxOffsets()), - tensor_operation::element_wise::PassThrough{}}; + Sequence, + Sequence>{in_grid_desc, + make_tuple(src_tensor.GetMultiIdxOffsets()), + out_grid_desc, + make_tuple(dst_tensor.GetMultiIdxOffsets()), + tensor_operation::element_wise::PassThrough{}}; transfer.Run(tie(in_grid_desc), tie(src_tensor.GetBuffer()), @@ -104,37 +104,25 @@ __device__ void copy(const SrcTensorType& src_tensor, DstTensorType& dst_tensor) else if constexpr(SrcTensorType::IsDynamicBuffer && !DstTensorType::IsDynamicBuffer) { // Perform copy from DynamicBuffer to StaticBuffer - const auto src_dst_slice_origin = + const auto dst_slice_origin_idxs = generate_tuple([&](auto) { return I0; }, Number{}); - constexpr auto src_vector_tensor_lengths = generate_sequence_v2( - [&](auto I) { - if constexpr(I == VectorDim) - { - return Number{}; - } - else - { - return I1; - } - }, - Number{}); - - auto transfer = - ThreadwiseTensorSliceTransfer_v4r1, - remove_cvref_t, - decltype(thread_slice_lengths), - decltype(dim_access_order), - decltype(src_vector_tensor_lengths), - decltype(dim_access_order)>{ - src_tensor.GetMultiIdxOffsets()}; + auto transfer = ThreadwiseTensorSliceTransfer_v2< + std::remove_const_t, + std::remove_const_t, + remove_cvref_t, + remove_cvref_t, + decltype(thread_slice_lengths), + decltype(dim_access_order), + VectorDim, + ScalarPerVector, + I1, + false, + false>{in_grid_desc, src_tensor.GetMultiIdxOffsets()}; transfer.Run(in_grid_desc, - src_dst_slice_origin, src_tensor.GetBuffer(), out_grid_desc, - src_dst_slice_origin, + dst_slice_origin_idxs, dst_tensor.GetBuffer()); } else @@ -183,10 +171,12 @@ template -__device__ void blockwise_copy(const SrcTensorType& src_tensor, - DstTensorType& dst_tensor, - [[maybe_unused]] ThreadLayoutTuple& thread_layout) + typename ThreadShape, + typename ThreadUnrolledDesc> +__device__ void +blockwise_copy(const SrcTensorType& src_tensor, + DstTensorType& dst_tensor, + [[maybe_unused]] const Layout& thread_layout) { static_assert(SrcTensorType::IsDynamicBuffer && DstTensorType::IsDynamicBuffer); static_assert(is_detected::value); @@ -199,12 +189,12 @@ __device__ void blockwise_copy(const SrcTensorType& src_tensor, constexpr auto tile_lengths_seq = generate_sequence_v2([](auto I) { return size(SrcShapeType{}.At(I)); }, Number{}); - constexpr auto thread_layout_seq = generate_sequence_v2( - [](auto I) { return size(ThreadLayoutTuple{}.At(I)); }, Number{}); + constexpr auto thread_layout_seq = + generate_sequence_v2([](auto I) { return size(ThreadShape{}); }, Number{}); constexpr auto dim_access_order = generate_sequence_v2( [](auto I) { return DimAccessOrderTuple{}.At(I); }, Number{}); - using ThisThreadBlock = ThisThreadBlock; + using ThisThreadBlock = ThisThreadBlock; // Perform copy between DynamicBuffers auto transfer = ThreadGroupTensorSliceTransfer_v7< diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp index 9b8c0543fd..e41cd5bd8a 100644 --- a/include/ck/wrapper/operations/gemm.hpp +++ b/include/ck/wrapper/operations/gemm.hpp @@ -48,8 +48,9 @@ __device__ constexpr auto GetBlockDescriptor() /** * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be - * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) and B - * data layout must be (NPerBlock, KPerBlock). + * stored in Vgpr register. A data layout must be (MPerBlock, KPerBlock) or + * (K0PerBlock, MPerBlock, K1) and B data layout must be (NPerBlock, KPerBlock) + * or (K0PerBlock, NPerBlock, K1). * * \note C output Vgpr register layout (8D): * - MXdlPerWave - The number of MFMA instructions run by single wave in M @@ -71,9 +72,9 @@ __device__ constexpr auto GetBlockDescriptor() * \tparam BlockSize Tensor to pad. * \tparam GemmTraits Traits of gemm xdl operation. * \param a_local_tile_tensor A tensor in LDS memory for blockwise gemm - * (MPerBlock, KPerBlock) layout. + * (MPerBlock, KPerBlock) or (K0PerBlock, MPerBlock, K1) layout. * \param b_local_tile_tensor B tensor in LDS memory for blockwise gemm - * (NPerBlock, KPerBlock) layout. + * (NPerBlock, KPerBlock) or (K0PerBlock, NPerBlock, K1) layout. * \param c_reg_tensor C tensor VGPR memory for blockwise gemm. */ template {}; + static_assert(ATensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(BTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Lds); static_assert(CTensorType::TensorBufferAddressSpace == MemoryTypeEnum::Vgpr); @@ -99,10 +102,18 @@ __device__ void blockwise_gemm_xdl(const ATensorType& a_local_tile_tensor, using ATileLayout = remove_cvref_t; using BTileLayout = remove_cvref_t; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; + using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; constexpr auto I7 = Number<7>{}; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_integer = is_same_v || is_same_v || is_same_v; using GemmAccDataType = std::conditional_t; + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BlockwiseGemmXdlops = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; }, Number<8>{}); + + auto sliced_desc = transform_tensor_descriptor( + partition_desc, + make_tuple( + make_slice_transform(partition_shape.At(Number<0>{}), + m_thread_data_on_grid_idx[I0], + partition_shape.At(Number<0>{}) + m_thread_data_on_grid_idx[I0]), + make_slice_transform(partition_shape.At(Number<1>{}), + n_thread_data_on_grid_idx[I0], + partition_shape.At(Number<1>{}) + n_thread_data_on_grid_idx[I0]), + make_slice_transform(partition_shape.At(Number<2>{}), + m_thread_data_on_grid_idx[I1], + partition_shape.At(Number<2>{}) + m_thread_data_on_grid_idx[I1]), + make_slice_transform(partition_shape.At(Number<3>{}), + n_thread_data_on_grid_idx[I1], + partition_shape.At(Number<3>{}) + n_thread_data_on_grid_idx[I1]), + make_slice_transform(partition_shape.At(Number<4>{}), + m_thread_data_on_grid_idx[I2], + partition_shape.At(Number<4>{}) + m_thread_data_on_grid_idx[I2]), + make_slice_transform(partition_shape.At(Number<5>{}), + m_thread_data_on_grid_idx[I3], + partition_shape.At(Number<5>{}) + m_thread_data_on_grid_idx[I3]), + make_slice_transform(partition_shape.At(Number<6>{}), + m_thread_data_on_grid_idx[I4], + partition_shape.At(Number<6>{}) + m_thread_data_on_grid_idx[I4]), + make_slice_transform(partition_shape.At(Number<7>{}), + n_thread_data_on_grid_idx[I2], + partition_shape.At(Number<7>{}) + n_thread_data_on_grid_idx[I2])), + lower_upper_dims, + lower_upper_dims); + const auto partition_layout = - Layout, decltype(partition_desc)>( - partition_shape, partition_desc); + Layout, decltype(sliced_desc)>( + partition_shape, sliced_desc); auto partition_tensor = make_tensor( c_local_tile_tensor.GetPointer(), partition_layout); - partition_tensor.SetMultiIdxOffset(make_multi_index(m_thread_data_on_grid_idx[I0], - n_thread_data_on_grid_idx[I0], - m_thread_data_on_grid_idx[I1], - n_thread_data_on_grid_idx[I1], - m_thread_data_on_grid_idx[I2], - m_thread_data_on_grid_idx[I3], - m_thread_data_on_grid_idx[I4], - n_thread_data_on_grid_idx[I2])); return partition_tensor; } @@ -292,14 +337,22 @@ __host__ __device__ constexpr auto make_blockwise_gemm_xdl_c_vgpr() constexpr auto I6 = Number<6>{}; constexpr auto I7 = Number<7>{}; + static_assert(typename ATileLayout::LayoutShape{}.Size() == + typename BTileLayout::LayoutShape{}.Size()); + constexpr bool is_integer = is_same_v || is_same_v || is_same_v; using GemmAccDataType = std::conditional_t; + constexpr bool is_3d_desc = typename ATileLayout::LayoutShape{}.Size() == I3; using ABlockDesc_K0_M_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BBlockDesc_K0_N_K1_Type = - decltype(detail::GetBlockDescriptor()); + conditional_t())>; using BlockwiseGemmXdlops = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1, decltype(vgpr_desc)>( vgpr_shape, vgpr_desc); // Get vector type for Vgpr - using BlockwiseGemmCThreadBufferType = - remove_reference_t; - using VgprVectorType = typename BlockwiseGemmCThreadBufferType::V; + constexpr index_t ScalarPerVector = BlockwiseGemmXdlops::xdlops_gemm.GetRegSizePerXdlops(); + using VgprVectorType = typename vector_type::type; return ck::wrapper::make_register_tensor( vgpr_layout); } diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index e344399dbf..6946e79ea4 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -172,10 +172,10 @@ __host__ __device__ constexpr auto GenerateUpperDims(const Tuple& } } -template +template __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& idx, const Shape& shape, - const FlattenDescriptor& flatten_desc) + const UnrolledDescriptor& flatten_desc) { constexpr auto old_shape_dims = decltype(UnrollNestedTuple(shape))::Size(); diff --git a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp index 8301636a9f..54804dea3c 100644 --- a/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp +++ b/include/ck/wrapper/traits/blockwise_gemm_xdl_traits.hpp @@ -20,48 +20,57 @@ namespace wrapper { * \tparam K1Value The number of K-dim elements that are packed together as * a separate logical dimension. Usually aligns with vector load size. */ -template +template struct BlockwisGemmXdlTraits { - static constexpr index_t MPerXDL = MPerXDLValue; - static constexpr index_t NPerXDL = NPerXDLValue; - static constexpr index_t MXdlPerWave = MXdlPerWaveValue; - static constexpr index_t NXdlPerWave = NXdlPerWaveValue; - static constexpr index_t K1 = K1Value; + static constexpr auto MPerXDL = MPerXDLValue{}; + static constexpr auto NPerXDL = NPerXDLValue{}; + static constexpr auto MXdlPerWave = MXdlPerWaveValue{}; + static constexpr auto NXdlPerWave = NXdlPerWaveValue{}; + static constexpr auto K1 = K1Value{}; }; // K1 = 4 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<4>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<4>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 4> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_4K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<4>> { }; // K1 = 8 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<8>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<8>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 8> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_8K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<8>> { }; // K1 = 16 -struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 4, 2, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_4x2XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<4>, Number<2>, Number<16>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 4, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_2x4XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<4>, Number<16>> { }; -struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 : BlockwisGemmXdlTraits<32, 32, 2, 2, 16> +struct BlockwisGemmXdlTraits_32x32Xdl_2x2XdlPerWave_16K1 + : BlockwisGemmXdlTraits, Number<32>, Number<2>, Number<2>, Number<16>> { }; diff --git a/include/ck/wrapper/utils/kernel_utils.hpp b/include/ck/wrapper/utils/kernel_utils.hpp new file mode 100644 index 0000000000..add94ec6ae --- /dev/null +++ b/include/ck/wrapper/utils/kernel_utils.hpp @@ -0,0 +1,14 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" + +namespace ck { +namespace wrapper { + +#define __CK_WRAPPER_LAUNCH_BOUNDS__ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) + +} // namespace wrapper +} // namespace ck diff --git a/include/ck/wrapper/utils/layout_utils.hpp b/include/ck/wrapper/utils/layout_utils.hpp index d04bd5078b..e077fade2c 100644 --- a/include/ck/wrapper/utils/layout_utils.hpp +++ b/include/ck/wrapper/utils/layout_utils.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace wrapper { @@ -29,6 +30,7 @@ template using is_tuple = decltype(std::declval().IsTuple()); namespace { +namespace detail { /** * \brief Generate packed (column-major) strides if not passed * @@ -83,6 +85,7 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha return make_naive_tensor_descriptor(unrolled_shape, unrolled_strides); } } +} // namespace detail } // namespace /// @endcond @@ -98,8 +101,9 @@ __host__ __device__ constexpr auto MakeUnrolledDescriptor(const LayoutShape& sha template __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides& strides) { - using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Strides{})); - return Layout(shape, MakeUnrolledDescriptor(shape, strides)); + using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{})); + return Layout(shape, + detail::MakeUnrolledDescriptor(shape, strides)); } /** @@ -112,13 +116,12 @@ __host__ __device__ constexpr auto make_layout(const Shape& shape, const Strides template __host__ __device__ constexpr auto make_layout(const Shape& shape) { - using UnrolledDescriptorType = decltype(MakeUnrolledDescriptor(Shape{}, Tuple<>{})); - return Layout(shape, MakeUnrolledDescriptor(shape, Tuple<>{})); + using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{})); + return Layout(shape, + detail::MakeUnrolledDescriptor(shape, Tuple<>{})); } - // Layout helpers // get - /** * \private * \brief Get dim. @@ -152,8 +155,8 @@ __host__ __device__ constexpr auto get(const Tuple& tuple) * \param layout Layout to create sub layout. * \return Requsted sub layout. */ -template -__host__ __device__ constexpr auto get(const Layout& layout) +template +__host__ __device__ constexpr auto get(const Layout& layout) { const auto& shape = layout.GetShape(); const auto new_shape = get(shape); @@ -427,5 +430,91 @@ __host__ __device__ constexpr const auto& shape(const LayoutType& layout) return layout.GetShape(); } +// pad +/** + * \brief Pad layout shapes to be adjusted to tile lengths. + * + * + * \param layout Layout to pad. + * \param tile_lengths Tile lengths to align layout shape. + * \return Padded layout. + */ +template +__host__ __device__ constexpr auto pad(const Layout& layout, + const TileLengths& tile_lengths) +{ + auto& unrolled_desc = layout.GetUnrolledDescriptor(); + // Generate sequence with ones to mark that all dims will be padded + constexpr auto do_pads_seq = + generate_sequence_v2([](auto) { return Number<1>{}; }, Number{}); + // Create descriptor with padding + auto padded_desc = + tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq); + // Generate padded shape + const auto padded_shape = generate_tuple( + [&](auto i) { return padded_desc.GetLength(Number{}); }, Number{}); + // Create layout + return Layout(padded_shape, padded_desc); +} + +// unmerge +/** + * \brief Unmerge selected dim in layout. + * + * \tparam Idx Index to dimension being unmerged. + * \param layout Layout to pad. + * \param new_lengths Dimensions into which the indicated dimension will be divided. + * \param new_indexes Indexes to shuffle dims. Dims for unmerged dim should be nested. + * \return Unmerged layout. + */ +template +__host__ __device__ constexpr auto unmerge(const Layout& layout, + const NewLengths& new_lengths, + [[maybe_unused]] const NewIdxs& new_indexes) +{ + const auto& layout_shape = shape(layout); + auto& unrolled_desc = layout.GetUnrolledDescriptor(); + constexpr auto dims = Shape::Size(); + // Generate transforms + const auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == Idx) + { + return make_unmerge_transform(new_lengths); + } + else + { + return make_pass_through_transform(layout_shape.At(i)); + } + }, + Number{}); + + constexpr auto lower_dims = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + constexpr auto upper_dims = generate_tuple( + [&](auto i) { + if constexpr(is_detected>::value) + { + constexpr auto idxs_tuple = tuple_element_t{}; + return to_sequence(idxs_tuple); + } + else + { + constexpr index_t index = tuple_element_t{}; + return Sequence{}; + } + }, + Number{}); + + const auto unmerged_desc = + transform_tensor_descriptor(unrolled_desc, transforms, lower_dims, upper_dims); + const auto unmerged_shape = + generate_tuple([&](auto i) { return unmerged_desc.GetLength(Number{}); }, + Number{}); + + // Create layout + return Layout(unmerged_shape, unmerged_desc); +} + } // namespace wrapper } // namespace ck diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 5638382dba..141e0a58e5 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -6,7 +6,6 @@ #include "tensor_utils.hpp" #include "layout_utils.hpp" -#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" @@ -44,8 +43,9 @@ __host__ __device__ constexpr auto CalculateLocalPartitionShape(const Tuple{} to keep. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Multi index after projection. */ template @@ -73,7 +73,7 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, } else { - return base_tuple.At(i_num); + return make_tuple(base_tuple.At(i_num)); } }, Number{}); @@ -86,8 +86,9 @@ ApplyProjection([[maybe_unused]] const MultiIndex& base_tuple, * \brief Calculate shape with dims from projection. * * \param shape Base tensor shape. - * \param projection Projection to remove selected dim from partitioning. - * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Shape with dims from projection */ template @@ -119,22 +120,14 @@ __host__ __device__ constexpr auto CalculateShapeWithProjection(const Tuple{}` to keep it. * \return Tuple with blocks number. */ template __host__ __device__ constexpr auto CalculateGridSize(const Tuple& shape, - const Tuple& tile_shape, - const Tuple& projection) + const Tuple& tile_shape) { - auto shape_with_projection = CalculateShapeWithProjection(shape, projection); return generate_tuple( - [&](auto i) { - return ck::math::integer_divide_ceil(size(shape_with_projection), - size(tile_shape)); - }, + [&](auto i) { return ck::math::integer_divide_ceil(size(shape), size(tile_shape)); }, Number::Size()>{}); } @@ -155,6 +148,54 @@ CalculateOffsetMultiIdxs(const ThreadIdxs& thread_idxs, return thread_idxs * partition_lengths_seq + old_offset_idxs; } +/** + * \brief Select dims to partition (skip if slice). + * + * \param block_idxs Input block indexes. + * \return Partitioned dims. + */ +template +__host__ __device__ constexpr auto GetDimsToPartition([[maybe_unused]] const BlockIdxs& block_idxs) +{ + const auto dims_to_partition = generate_tuple( + [&](auto i) { + if constexpr(!is_detected>::value) + { + return Number{}; + } + else + { + return Tuple<>{}; + } + }, + Number{}); + // Remove empty tuples + return UnrollNestedTuple<0, 1>(dims_to_partition); +} + +/** + * \brief Replace slices with zeros (Slice dims are not partitioned). + * + * \param block_idxs Input block indexes. + * \return Parsed dims. + */ +template +__host__ __device__ constexpr auto ReplaceSlicesWithZeros(const BlockIdxs& block_idxs) +{ + return generate_tuple( + [&](auto i) { + if constexpr(!is_detected>::value) + { + return block_idxs.At(i); + } + else + { + return Number<0>{}; + } + }, + Number{}); +} + /** * \brief Calculate default projection. * @@ -168,6 +209,31 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) return generate_tuple([&](auto) { return Number<1>{}; }, Number{}); } +/** + * \brief Calculate thread multi index from 1d thread index. + * + * \param thread_layout Layout of threads (could not be nested). + * \param thread_id Thread index represented as integer. + * \return Multi index. + */ +template +__host__ __device__ constexpr auto CalculateThreadMultiIdx( + [[maybe_unused]] const Layout& thread_layout, + const index_t thread_id) +{ + static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1, + "Thread layout should not be transformed."); + constexpr auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{}); + constexpr auto shape = ThreadShape{}; + constexpr auto strides = embed_transform.coefficients_; + + return generate_tuple( + [&](auto i) { + constexpr auto num_i = Number{}; + return (thread_id / strides.At(num_i)) % shape.At(num_i); + }, + Number{}); +} } // namespace detail } // namespace @@ -176,51 +242,62 @@ GenerateDefaultProjection([[maybe_unused]] const TileShape tile_shape) * is supported). * * \param tensor Tensor for partition. - * \param thread_lengths Layout of threads (could not be nested). + * \param thread_layout Layout of threads (could not be transformed). * \param thread_id Thread index represented as integer. * \param projection Projection is used to remove selected dim from * partitioning. Use `slice(X)` to remove dimension, where X is dim * size. Use `Number<1>{}` to keep it. * \return Partition tensor. */ -template +template __host__ __device__ constexpr auto make_local_partition(TensorType& tensor, - [[maybe_unused]] const ThreadLengthsTuple& thread_lengths, + [[maybe_unused]] const Layout& thread_layout, const index_t thread_id, const ProjectionTuple& projection) { - static_assert(!IsNestedTuple(ThreadLengthsTuple{})); + static_assert(!IsNestedTuple(ThreadShape{})); // Calculate new partition shape const auto& tensor_shape = shape(tensor); // Calculate projected thread lengths constexpr auto projected_thread_lengths = - detail::ApplyProjection(ThreadLengthsTuple{}, ProjectionTuple{}); + detail::ApplyProjection(ThreadShape{}, ProjectionTuple{}); constexpr auto partition_shape = detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths); - // Create Thread Cluster Descriptor constexpr auto partition_shape_seq = generate_sequence_v2([&](auto I) { return size(partition_shape); }, Number{}); - constexpr auto thread_lengths_seq = - generate_sequence_v2([&](auto I) { return size(ThreadLengthsTuple{}); }, - Number{}); - constexpr auto thread_cluster_desc_ = make_cluster_descriptor(thread_lengths_seq); // Calculate thread idxs and offsets - const auto thread_idxs = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id)); + const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id); // Apply projection on thread idxs to remove not needed idxs const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection); const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets()); // Create new layout and tensor auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); + // Slice descriptor + const auto transforms = generate_tuple( + [&](auto i) { + return make_slice_transform(partition_shape.At(i), + offset_multi_idxs.At(i), + partition_shape.At(i) + offset_multi_idxs.At(i)); + }, + Number::Size()>{}); + const auto lower_upper_dims = + generate_tuple([&](auto i) { return Sequence{}; }, + Number::Size()>{}); + auto sliced_desc = + transform_tensor_descriptor(unrolled_desc, transforms, lower_upper_dims, lower_upper_dims); + // Create layout const auto partition_layout = - Layout, decltype(unrolled_desc)>( - partition_shape, unrolled_desc); + Layout, decltype(sliced_desc)>( + partition_shape, sliced_desc); auto partition_tensor = make_tensor(tensor.GetPointer(), partition_layout); // Apply offsets - partition_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); return partition_tensor; } @@ -233,12 +310,13 @@ make_local_partition(TensorType& tensor, * \param thread_id Thread index represented as integer. * \return Partition tensor. */ -template -__host__ __device__ constexpr auto make_local_partition(TensorType& tensor, - const ThreadLengthsTuple& thread_lengths, - const index_t thread_id) +template +__host__ __device__ constexpr auto +make_local_partition(TensorType& tensor, + const Layout& thread_lengths, + const index_t thread_id) { - const auto projection = detail::GenerateDefaultProjection(ThreadLengthsTuple{}); + const auto projection = detail::GenerateDefaultProjection(ThreadShape{}); return make_local_partition(tensor, thread_lengths, thread_id, projection); } @@ -252,21 +330,24 @@ __host__ __device__ constexpr auto make_local_partition(TensorType& tensor, * * \param tensor Tensor for partition. * \param tile_shape Shapes of requested tile. - * \param block_id Block index represented as integer. - * \param projection Projection to remove selected dim from partitioning. - * slice(X) to remove, where X is dim size, Number<1>{} to keep. + * \param block_idxs Tuple of block indexes represented as integer. If slice, + * then get whole dim. + * \param projection Projection is used to remove selected dim from + * partitioning. Use `slice(X)` to remove dimension, where X is dim + * size. Use `Number<1>{}` to keep it. * \return Tile tensor. */ -template +template __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, - const index_t block_id, + const BlockIdxs& block_idxs, const ProjectionTuple& projection) { static_assert(!IsNestedTuple(BlockShapeTuple{})); - - constexpr bool is_default_projection = - is_same_v; + static_assert(!IsNestedTuple(BlockIdxs{})); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -274,49 +355,77 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, auto& aligned_desc = layout(tensor).GetMergedNestingDescriptor(); - // TODO: Enable block_2_tile_map partitioning for non-default projection. - if constexpr(BlockShapeTuple::Size() == I2 && is_default_projection) + constexpr auto projected_tile_shape = + detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{}); + // Number of dims which are partitioned + constexpr auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{}); + const auto parsed_block_idxs = detail::ReplaceSlicesWithZeros(block_idxs); + if constexpr(decltype(dims_to_partition)::Size() == I2) { - // Optimized version for 2d tile shape [MxK] + const auto shape_with_projection_dims = + detail::CalculateShapeWithProjection(shape(tensor), projection); + // Set Value for M, N partition + const auto M = shape_with_projection_dims.At(dims_to_partition.At(I0)); + const auto N = shape_with_projection_dims.At(dims_to_partition.At(I1)); + constexpr auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0)); + constexpr auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1)); + auto m_n_desc = make_naive_tensor_descriptor_packed(make_tuple(M, N)); + // Get 1D block id + const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape); + const auto block_lengths_desc = make_naive_tensor_descriptor_packed(grid_size); + const index_t block_id_1d = block_lengths_desc.CalculateOffset(parsed_block_idxs); + // Optimized version for 2d tile shape [MxN] const auto block_2_tile_map = - BlockToCTileMap_M00_N0_M01Adapt>(aligned_desc); + BlockToCTileMap_M00_N0_M01Adapt>(m_n_desc); const auto block_work_idx = - block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); + block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id_1d)); const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * size<0>(tile_shape)); - const index_t k_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * size<1>(tile_shape)); - const auto offset_multi_idxs = - make_tuple(m_block_data_idx_on_grid, k_block_data_idx_on_grid); + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + // Apply 0 for non partitioned dims + const auto offset_multi_idxs = generate_tuple( + [&](auto i) { + if constexpr(i == dims_to_partition.At(I0)) + { + return m_block_data_idx_on_grid; + } + else if constexpr(i == dims_to_partition.At(I1)) + { + return n_block_data_idx_on_grid; + } + else + { + return Number<0>{}; + } + }, + Number{}); + const auto projected_offset_multi_idxs = + detail::ApplyProjection(offset_multi_idxs, projection); // Create new layout and tensor const auto tile_layout = - Layout, decltype(aligned_desc)>(tile_shape, - aligned_desc); + Layout, decltype(aligned_desc)>( + projected_tile_shape, aligned_desc); auto tile_tensor = make_tensor(tensor.GetPointer(), tile_layout); // Apply offsets - tile_tensor.SetMultiIdxOffset(to_multi_index(offset_multi_idxs)); + tile_tensor.SetMultiIdxOffset(to_multi_index(projected_offset_multi_idxs)); return tile_tensor; } else { // Calculate offsets // Sequence with data to process per block - constexpr auto projected_tile_shape = - detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{}); using ProjectedTileShapeTuple = decltype(projected_tile_shape); constexpr auto projected_tile_shape_seq = generate_sequence_v2([](auto I) { return ProjectedTileShapeTuple{}.At(I); }, Number{}); // Tuple with number of blocks - const auto block_lengths = detail::CalculateGridSize(shape(tensor), tile_shape, projection); - const auto block_cluster_desc_ = make_cluster_descriptor(block_lengths); - const auto block_idxs = - block_cluster_desc_.CalculateBottomIndex(make_multi_index(block_id)); - const auto projected_block_idxs = detail::ApplyProjection(block_idxs, projection); - const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( + const auto projected_block_idxs = + to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection)); + const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs( projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets()); // Create new layout and tensor const auto tile_layout = @@ -338,52 +447,17 @@ __host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, * * \param tensor Tensor for partition. * \param tile_shape Shapes of requested tile. - * \param block_id Block index represented as integer. + * \param block_idxs Tuple of block indexes represented as integer. If slice, + * then get whole dim. * \return Tile tensor. */ -template -__host__ __device__ constexpr auto -make_local_tile(const TensorType& tensor, const BlockShapeTuple& tile_shape, const index_t block_id) +template +__host__ __device__ constexpr auto make_local_tile(const TensorType& tensor, + const BlockShapeTuple& tile_shape, + const BlockIdxs& block_idxs) { const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{}); - return make_local_tile(tensor, tile_shape, block_id, projection); -} - -/** - * \brief Pad tensor shapes to be adjusted to tile lengths. - * - * - * \param tensor Tensor to pad. - * \param tile_lengths Tile lengths to align tensor shape. - * \return Padded tensor. - */ -template -__host__ __device__ constexpr auto pad(const TensorType& tensor, const TileLengths& tile_lengths) -{ - const auto& tensor_shape = shape(tensor); - using TensorShapeType = remove_reference_t; - auto& unrolled_desc = layout(tensor).GetUnrolledDescriptor(); - // Generate sequence with ones to mark that all dims will be padded - constexpr auto do_pads_seq = - generate_sequence_v2([](auto) { return Number<1>{}; }, Number{}); - // Create descriptor with padding - auto padded_desc = - tensor_operation::device::PadTensorDescriptor(unrolled_desc, tile_lengths, do_pads_seq); - // Generate padded shape - const auto padded_shape = generate_tuple( - [&](auto i) { - const auto& dim = size(tensor_shape); - const auto& tile_length = size(tile_lengths); - return ck::math::integer_divide_ceil(dim, tile_length) * tile_length; - }, - Number{}); - // Create layout and tensor - const auto padded_layout = - Layout(padded_shape, padded_desc); - auto partition_tensor = - make_tensor(tensor.GetPointer(), padded_layout); - partition_tensor.SetMultiIdxOffset(tensor.GetMultiIdxOffsets()); - return partition_tensor; + return make_local_tile(tensor, tile_shape, block_idxs, projection); } } // namespace wrapper diff --git a/test/wrapper/CMakeLists.txt b/test/wrapper/CMakeLists.txt index cadc146795..383707828c 100644 --- a/test/wrapper/CMakeLists.txt +++ b/test/wrapper/CMakeLists.txt @@ -1,14 +1,21 @@ -add_gtest_executable(test_layout test_layout.cpp) -target_link_libraries(test_layout PRIVATE utility) -add_gtest_executable(test_tensor test_tensor.cpp) -target_link_libraries(test_tensor PRIVATE utility) -add_gtest_executable(test_copy test_copy.cpp) -target_link_libraries(test_copy PRIVATE utility) -add_gtest_executable(test_partition test_partition.cpp) -target_link_libraries(test_partition PRIVATE utility) +add_custom_target(test_wrapper) + +add_gtest_executable(test_wrapper_layout test_wrapper_layout.cpp) +target_link_libraries(test_wrapper_layout PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_layout) +add_gtest_executable(test_wrapper_tensor test_wrapper_tensor.cpp) +target_link_libraries(test_wrapper_tensor PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_tensor) +add_gtest_executable(test_wrapper_copy test_wrapper_copy.cpp) +target_link_libraries(test_wrapper_copy PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_copy) +add_gtest_executable(test_wrapper_partition test_wrapper_partition.cpp) +target_link_libraries(test_wrapper_partition PRIVATE utility) +add_dependencies(test_wrapper test_wrapper_partition) if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") - add_gtest_executable(test_gemm test_gemm.cpp) - target_link_libraries(test_gemm PRIVATE utility) + add_gtest_executable(test_wrapper_gemm test_wrapper_gemm.cpp) + target_link_libraries(test_wrapper_gemm PRIVATE utility) + add_dependencies(test_wrapper test_wrapper_gemm) endif() diff --git a/test/wrapper/test_gemm.cpp b/test/wrapper/test_gemm.cpp deleted file mode 100644 index 12245490d1..0000000000 --- a/test/wrapper/test_gemm.cpp +++ /dev/null @@ -1,257 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include - -#include "ck/library/utility/host_tensor.hpp" - -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" - -#include "ck/host_utility/kernel_launch.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/wrapper/layout.hpp" -#include "ck/wrapper/tensor.hpp" -#include "ck/wrapper/operations/copy.hpp" -#include "ck/wrapper/operations/gemm.hpp" - -template -void CheckResult(const std::vector& a_data, - const std::vector& b_data, - std::vector& c_m_n_device_result, - const ck::index_t M, - const ck::index_t N, - const ck::index_t K) -{ - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; - - Tensor a_m_k(HostTensorDescriptor({M, K})); - Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); - Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); - - a_m_k.mData = a_data; - b_k_n.mData = b_data; - - auto ref_op = ReferenceGemmInstance{}; - auto ref_invoker = ref_op.MakeInvoker(); - auto ref_argument = ref_op.MakeArgument( - a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); - - ref_invoker.Run(ref_argument); - EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); -} - -template -__global__ void DeviceGemm(const void* p_a, - const void* p_b, - void* p_c, - const ck::index_t M, - const ck::index_t N, - const ck::index_t K, - const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) -{ - constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); - constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); - constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); - - const auto a_global_layout = - ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); - const auto b_global_layout = - ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); - const auto c_global_layout = - ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); - - constexpr auto a_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(MPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); - constexpr auto b_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(NPerBlock, KPerBlock), ck::make_tuple(KPerBlock, ck::Number<1>{})); - constexpr auto c_tile_layout = ck::wrapper::make_layout( - ck::make_tuple(MPerBlock, NPerBlock), ck::make_tuple(NPerBlock, ck::Number<1>{})); - - auto a_global_tensor = ck::wrapper::make_tensor( - static_cast(p_a), a_global_layout); - auto b_global_tensor = ck::wrapper::make_tensor( - static_cast(p_b), b_global_layout); - auto c_global_tensor = ck::wrapper::make_tensor( - static_cast(p_c), c_global_layout); - - auto a_padded_global_tensor = ck::wrapper::pad(a_global_tensor, shape(a_tile_layout)); - auto b_padded_global_tensor = ck::wrapper::pad(b_global_tensor, shape(b_tile_layout)); - auto c_padded_global_tensor = ck::wrapper::pad(c_global_tensor, shape(c_tile_layout)); - - __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout)]; - __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout)]; - - auto a_lds_tensor = ck::wrapper::make_tensor( - static_cast(lds_a), a_tile_layout); - auto b_lds_tensor = ck::wrapper::make_tensor( - static_cast(lds_b), b_tile_layout); - - const ck::index_t block_idx = static_cast(blockIdx.x); - using DimAccessOrder = ck::Tuple, ck::Number<1>>; - constexpr ck::index_t vector_dim = 1; - - auto c_global_local_tile = ck::wrapper::make_local_tile( - c_padded_global_tensor, - tile_shape, - block_idx, - make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(KPerBlock))); - auto c_global_local_partition = - ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); - auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); - ck::wrapper::clear(c_vgpr_reg); - - const ck::index_t num_loop = ck::math::integer_divide_ceil(K, KPerBlock); - ck::index_t i = 0; - do - { - const auto k_slice = ck::wrapper::slice(i * KPerBlock, (i + 1) * KPerBlock); - auto a_padded_global_tensor_k_slice = a_padded_global_tensor(ck::wrapper::slice(), k_slice); - auto b_padded_global_tensor_k_slice = b_padded_global_tensor(ck::wrapper::slice(), k_slice); - auto a_global_local_tile = ck::wrapper::make_local_tile( - a_padded_global_tensor_k_slice, - tile_shape, - block_idx, - make_tuple(ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{})); - auto b_global_local_tile = ck::wrapper::make_local_tile( - b_padded_global_tensor_k_slice, - tile_shape, - block_idx, - make_tuple(ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{})); - - ck::wrapper::blockwise_copy( - a_global_local_tile, a_lds_tensor, thread_layout); - ck::wrapper::blockwise_copy( - b_global_local_tile, b_lds_tensor, thread_layout); - ck::block_sync_lds(); - ck::wrapper::blockwise_gemm_xdl( - a_lds_tensor, b_lds_tensor, c_vgpr_reg); - - ++i; - } while(i < num_loop); - - ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); -} - -template -void PerformGemm(const ck::index_t M, - const ck::index_t N, - const ck::index_t K, - const BlockShape& tile_shape, - const ThreadLayoutShape& thread_layout) -{ - // Global memory buffers - DeviceMem a_mem(M * K * sizeof(DataType)); - DeviceMem b_mem(K * N * sizeof(DataType)); - DeviceMem c_mem(M * N * sizeof(DataType)); - - std::vector a_data(M * K); - std::vector b_data(K * N); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); - ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); - - a_mem.ToDevice(a_data.data()); - b_mem.ToDevice(b_data.data()); - c_mem.SetZero(); - - const ck::index_t grid_size = - ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)) * - ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); - - const auto kernel = - DeviceGemm; - launch_and_time_kernel(StreamConfig{nullptr}, - kernel, - dim3(grid_size), - dim3(ck::wrapper::size(thread_layout)), - 0, - a_mem.GetDeviceBuffer(), - b_mem.GetDeviceBuffer(), - c_mem.GetDeviceBuffer(), - M, - N, - K, - tile_shape, - thread_layout); - - std::vector c_data(M * N); - c_mem.FromDevice(c_data.data()); - - CheckResult(a_data, b_data, c_data, M, N, K); -} - -TEST(TestGemm, Float) -{ - using DataType = float; - const auto thread_layout = ck::make_tuple(ck::Number<16>{}, ck::Number<16>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Int8) -{ - using DataType = int8_t; - const auto thread_layout = ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Half) -{ - using DataType = ck::half_t; - const auto thread_layout = ck::make_tuple(ck::Number<32>{}, ck::Number<8>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout); - // Irregular case - PerformGemm( - 129, 129, 67, tile_shape, thread_layout); -} - -TEST(TestGemm, Float_2x4_4x2_XdlPerWave) -{ - using DataType = float; - const auto thread_layout_4x2_xdl_per_wave = ck::make_tuple(ck::Number<16>{}, ck::Number<8>{}); - const auto thread_layout_2x4_xdl_per_wave = ck::make_tuple(ck::Number<8>{}, ck::Number<16>{}); - const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout_4x2_xdl_per_wave); - PerformGemm( - 512, 512, 128, tile_shape, thread_layout_2x4_xdl_per_wave); -} diff --git a/test/wrapper/test_copy.cpp b/test/wrapper/test_wrapper_copy.cpp similarity index 83% rename from test/wrapper/test_copy.cpp rename to test/wrapper/test_wrapper_copy.cpp index e7fa3c539b..4721006435 100644 --- a/test/wrapper/test_copy.cpp +++ b/test/wrapper/test_wrapper_copy.cpp @@ -20,23 +20,25 @@ template __global__ void TestCopyDevice(const InputTensor input_tensor, OutputTensor output_tensor, const BlockShape tile_shape, - const ThreadLayoutShape thread_layout) + const ThreadLayout thread_layout) { __shared__ ck::index_t p_shared[ck::wrapper::size(tile_shape)]; const auto tensor_lds = ck::wrapper::make_tensor( p_shared, ck::wrapper::make_layout(tile_shape)); - const auto block_idx = static_cast(blockIdx.x); + const auto block_idxs = + ck::make_tuple(static_cast(blockIdx.x), static_cast(blockIdx.y)); // Get local tiles for global memory - const auto input_local_tile = ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idx); + const auto input_local_tile = + ck::wrapper::make_local_tile(input_tensor, tile_shape, block_idxs); const auto output_local_tile = - ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idx); + ck::wrapper::make_local_tile(output_tensor, tile_shape, block_idxs); // Get partition per thread const auto input_local_partition = @@ -49,7 +51,7 @@ __global__ void TestCopyDevice(const InputTensor input_tensor, // Allocate VGPR auto tensor_vgpr = ck::wrapper::make_register_tensor( - layout(lds_local_partition)); + ck::wrapper::make_layout(shape(lds_local_partition))); // Perform copy if constexpr(UseOptimizedCopy) @@ -99,11 +101,14 @@ void PerformCopyGlobalToGlobalViaLDS() auto output_tensor_global = ck::wrapper::make_tensor( static_cast(out_buf.GetDeviceBuffer()), layout); - const auto thread_layout = ck::make_tuple(ck::Number<1>{}, ck::Number<32>{}); - const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<1>{}, ck::Number<32>{})); + const auto tile_shape = ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}); - const ck::index_t grid_size = ck::math::integer_divide_ceil( - ck::wrapper::size(input_tensor_global), ck::wrapper::size(tile_shape)); + const ck::index_t grid_size_x = ck::math::integer_divide_ceil( + ck::wrapper::size<0>(input_tensor_global), ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = ck::math::integer_divide_ceil( + ck::wrapper::size<1>(input_tensor_global), ck::wrapper::size<1>(tile_shape)); const auto kernel = TestCopyDevice; launch_and_time_kernel(StreamConfig{}, kernel, - dim3(grid_size), + dim3(grid_size_x, grid_size_y, 1), dim3(ck::wrapper::size(thread_layout)), 0, input_tensor_global, diff --git a/test/wrapper/test_wrapper_gemm.cpp b/test/wrapper/test_wrapper_gemm.cpp new file mode 100644 index 0000000000..fd2cb7d4f3 --- /dev/null +++ b/test/wrapper/test_wrapper_gemm.cpp @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include + +#include "ck/library/utility/host_tensor.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/wrapper/layout.hpp" +#include "ck/wrapper/tensor.hpp" +#include "ck/wrapper/operations/copy.hpp" +#include "ck/wrapper/operations/gemm.hpp" +#include "ck/wrapper/utils/kernel_utils.hpp" + +template +void CheckResult(const std::vector& a_data, + const std::vector& b_data, + std::vector& c_m_n_device_result, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + Tensor a_m_k(HostTensorDescriptor({M, K})); + Tensor b_k_n(HostTensorDescriptor({K, N}, {1, K})); + Tensor c_m_n_host_result(HostTensorDescriptor({M, N})); + + a_m_k.mData = a_data; + b_k_n.mData = b_data; + + auto ref_op = ReferenceGemmInstance{}; + auto ref_invoker = ref_op.MakeInvoker(); + auto ref_argument = ref_op.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + EXPECT_TRUE(ck::utils::check_err(c_m_n_device_result, c_m_n_host_result.mData)); +} + +template +__device__ auto ApplyPadding(const Layout& layout, const PaddingDims& padding_dims) +{ + if constexpr(DoPad) + { + return ck::wrapper::pad(layout, padding_dims); + } + else + { + return layout; + } +} + +template +__global__ void __CK_WRAPPER_LAUNCH_BOUNDS__ DeviceGemm(const void* p_a, + const void* p_b, + void* p_c, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape tile_shape, + const ThreadLayout thread_layout) +{ + constexpr auto MPerBlock = ck::wrapper::size<0>(tile_shape); + constexpr auto NPerBlock = ck::wrapper::size<1>(tile_shape); + constexpr auto KPerBlock = ck::wrapper::size<2>(tile_shape); + constexpr auto K1 = GemmTraits::K1; + constexpr auto K0PerBlock = KPerBlock / K1; + const auto K0 = ck::math::integer_divide_ceil(K, K1); + + const auto tile_shape_k0_m_n_k1 = ck::make_tuple(K0PerBlock, MPerBlock, NPerBlock, K1); + + const auto a_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, K), ck::make_tuple(K, 1)); + const auto b_global_layout = + ck::wrapper::make_layout(ck::make_tuple(N, K), ck::make_tuple(K, 1)); + const auto c_global_layout = + ck::wrapper::make_layout(ck::make_tuple(M, N), ck::make_tuple(N, 1)); + + auto a_padded_global_layout = + ApplyPadding(a_global_layout, ck::make_tuple(MPerBlock, KPerBlock)); + auto b_padded_global_layout = + ApplyPadding(b_global_layout, ck::make_tuple(NPerBlock, KPerBlock)); + auto c_padded_global_layout = + ApplyPadding(c_global_layout, ck::make_tuple(MPerBlock, NPerBlock)); + + // Reshape from M,K to K0,M,K1 + const auto reshaped_dims_idxs = + ck::make_tuple(ck::Number<1>{}, ck::make_tuple(ck::Number<0>{}, ck::Number<2>{})); + auto a_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(a_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + auto b_padded_unmerged_global_layout = + ck::wrapper::unmerge<1>(b_padded_global_layout, ck::make_tuple(K0, K1), reshaped_dims_idxs); + + auto a_global_tensor = ck::wrapper::make_tensor( + static_cast(p_a), a_padded_unmerged_global_layout); + auto b_global_tensor = ck::wrapper::make_tensor( + static_cast(p_b), b_padded_unmerged_global_layout); + auto c_global_tensor = ck::wrapper::make_tensor( + static_cast(p_c), c_padded_global_layout); + + // Add extra M and N + constexpr auto a_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, MPerBlock, K1), + ck::make_tuple((MPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + constexpr auto b_tile_layout = ck::wrapper::make_layout( + ck::make_tuple(K0PerBlock, NPerBlock, K1), + ck::make_tuple((NPerBlock + ck::Number<1>{}) * K1, K1, ck::Number<1>{})); + + __shared__ DataType lds_a[ck::wrapper::size(a_tile_layout) + NPerBlock]; + __shared__ DataType lds_b[ck::wrapper::size(b_tile_layout) + NPerBlock]; + + auto a_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_a), a_tile_layout); + auto b_lds_tensor = ck::wrapper::make_tensor( + static_cast(lds_b), b_tile_layout); + + const auto block_idxs = ck::make_tuple(ck::wrapper::slice(), + static_cast(blockIdx.x), + static_cast(blockIdx.y), + ck::wrapper::slice()); + using DimAccessOrder = ck::Tuple, ck::Number<0>, ck::Number<2>>; + constexpr ck::index_t vector_dim = 2; + + auto c_global_local_tile = + ck::wrapper::make_local_tile(c_global_tensor, + tile_shape_k0_m_n_k1, + block_idxs, + make_tuple(ck::wrapper::slice(K0PerBlock), + ck::Number<1>{}, + ck::Number<1>{}, + ck::wrapper::slice(K1))); + auto c_global_local_partition = + ck::wrapper::make_blockwise_gemm_xdl_c_local_partition(c_global_local_tile); + auto c_vgpr_reg = ck::wrapper::make_blockwise_gemm_xdl_c_vgpr(); + ck::wrapper::clear(c_vgpr_reg); + + auto a_lds_tensor_local_partition = + ck::wrapper::make_local_partition(a_lds_tensor, thread_layout, threadIdx.x); + auto b_lds_tensor_local_partition = + ck::wrapper::make_local_partition(b_lds_tensor, thread_layout, threadIdx.x); + + auto make_global_partition = [&](auto tensor, auto projection, ck::index_t i) { + const auto k_slice = + ck::make_tuple(ck::wrapper::slice(i * K0PerBlock, (i + 1) * K0PerBlock), + ck::wrapper::slice(), + ck::wrapper::slice()); + auto local_tile = ck::wrapper::make_local_tile( + tensor(k_slice), tile_shape_k0_m_n_k1, block_idxs, projection); + return ck::wrapper::make_local_partition(local_tile, thread_layout, threadIdx.x); + }; + + auto a_global_local_partition = make_global_partition( + a_global_tensor, + make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + 0); + auto b_global_local_partition = make_global_partition( + b_global_tensor, + make_tuple(ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + 0); + + // (row-major vgpr layout) + auto a_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(a_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + auto b_vgpr_tensor = + ck::wrapper::make_register_tensor( + ck::wrapper::make_layout( + shape(b_global_local_partition), + ck::make_tuple(ck::wrapper::size<1>(a_global_local_partition) * + ck::wrapper::size<2>(a_global_local_partition), + ck::wrapper::size<2>(a_global_local_partition), + ck::Number<1>{}))); + + ck::wrapper::copy(a_global_local_partition, + a_vgpr_tensor); + ck::wrapper::copy(b_global_local_partition, + b_vgpr_tensor); + ck::wrapper::copy(a_vgpr_tensor, + a_lds_tensor_local_partition); + ck::wrapper::copy(b_vgpr_tensor, + b_lds_tensor_local_partition); + + const ck::index_t num_loop = + __builtin_amdgcn_readfirstlane(ck::math::integer_divide_ceil(K, KPerBlock)); + if(num_loop > 1) + { + ck::index_t i = 0; + do + { + auto a_global_local_partition_i = make_global_partition( + a_global_tensor, + make_tuple( + ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(N), ck::Number<1>{}), + i + 1); + auto b_global_local_partition_i = make_global_partition( + b_global_tensor, + make_tuple( + ck::Number<1>{}, ck::wrapper::slice(M), ck::Number<1>{}, ck::Number<1>{}), + i + 1); + + ck::wrapper::copy( + a_global_local_partition_i, a_vgpr_tensor); + + ck::block_sync_lds(); + ck::wrapper::copy( + b_global_local_partition_i, b_vgpr_tensor); + + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ck::block_sync_lds(); + ck::wrapper::copy( + a_vgpr_tensor, a_lds_tensor_local_partition); + ck::wrapper::copy( + b_vgpr_tensor, b_lds_tensor_local_partition); + + ++i; + } while(i < (num_loop - 1)); + } + ck::block_sync_lds(); + ck::wrapper::blockwise_gemm_xdl( + a_lds_tensor, b_lds_tensor, c_vgpr_reg); + + ck::wrapper::copy(c_vgpr_reg, c_global_local_partition); +} + +template +void PerformGemm(const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const BlockShape& tile_shape, + const ThreadLayout& thread_layout) +{ + // Global memory buffers + DeviceMem a_mem(M * K * sizeof(DataType)); + DeviceMem b_mem(K * N * sizeof(DataType)); + DeviceMem c_mem(M * N * sizeof(DataType)); + + std::vector a_data(M * K); + std::vector b_data(K * N); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_data); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_data); + + a_mem.ToDevice(a_data.data()); + b_mem.ToDevice(b_data.data()); + c_mem.SetZero(); + + const ck::index_t grid_size_x = + ck::math::integer_divide_ceil(M, ck::wrapper::size<0>(tile_shape)); + const ck::index_t grid_size_y = + ck::math::integer_divide_ceil(N, ck::wrapper::size<1>(tile_shape)); + + const auto kernel = + DeviceGemm; + const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, + kernel, + dim3(grid_size_x, grid_size_y, 1), + dim3(ck::wrapper::size(thread_layout)), + 0, + a_mem.GetDeviceBuffer(), + b_mem.GetDeviceBuffer(), + c_mem.GetDeviceBuffer(), + M, + N, + K, + tile_shape, + thread_layout); + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(DataType) * M * K + sizeof(DataType) * K * N + sizeof(DataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_btype / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << std::endl; + + std::vector c_data(M * N); + c_mem.FromDevice(c_data.data()); + CheckResult(a_data, b_data, c_data, M, N, K); +} + +TEST(TestGemm, Float) +{ + using DataType = float; + // (dim1, dim2, dim0 thread layout) + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<16>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Int8) +{ + using DataType = int8_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<64>{}); + PerformGemm(512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Half) +{ + using DataType = ck::half_t; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<128>{}, ck::Number<128>{}, ck::Number<32>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); + // Irregular case + PerformGemm( + 129, 129, 67, tile_shape, thread_layout); +} + +TEST(TestGemm, Float_2x4_4x2_XdlPerWave) +{ + using DataType = float; + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<1>{}, ck::Number<4>{}, ck::Number<1>{})); + const auto tile_shape = ck::make_tuple(ck::Number<256>{}, ck::Number<128>{}, ck::Number<16>{}); + PerformGemm( + 512, 512, 128, tile_shape, thread_layout); +} diff --git a/test/wrapper/test_layout.cpp b/test/wrapper/test_wrapper_layout.cpp similarity index 99% rename from test/wrapper/test_layout.cpp rename to test/wrapper/test_wrapper_layout.cpp index a128a6d84f..0b07303299 100644 --- a/test/wrapper/test_layout.cpp +++ b/test/wrapper/test_wrapper_layout.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/test/wrapper/test_partition.cpp b/test/wrapper/test_wrapper_partition.cpp similarity index 79% rename from test/wrapper/test_partition.cpp rename to test/wrapper/test_wrapper_partition.cpp index 8b6d220cd7..08d196c4ca 100644 --- a/test/wrapper/test_partition.cpp +++ b/test/wrapper/test_wrapper_partition.cpp @@ -29,8 +29,11 @@ TEST(TestPartition, LocalPartition) const auto tensor = ck::wrapper::make_tensor(data.data(), layout); - const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); - const auto thread_layout = ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}); + const auto thread_steps = ck::make_tuple(ck::Number<1>{}, ck::Number<8>{}, ck::Number<1>{}); + // row-major thread layout + const auto thread_layout = + ck::wrapper::make_layout(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}, ck::Number<1>{}), + ck::make_tuple(ck::Number<8>{}, ck::Number<1>{}, ck::Number<1>{})); // 3d partition on 2d shape (calculate partition on 3d thread layout, and then skip first dim) const auto thread_projection = ck::make_tuple(ck::wrapper::slice(4), ck::Number<1>{}, ck::Number<1>{}); @@ -70,29 +73,37 @@ TEST(TestPartition, LocalTile) ck::make_tuple(ck::Number<2>{}, ck::Number<4>{}, ck::Number<2>{}, ck::Number<2>{}); const auto block_projection = ck::make_tuple(ck::Number<1>{}, ck::Number<1>{}, ck::Number<1>{}, ck::wrapper::slice(2)); - constexpr ck::index_t projection_block_dim = ck::Number<2>{}; - const auto num_blocks = + + const auto grid_shape = ck::make_tuple(ck::wrapper::size<0>(shape) / ck::wrapper::size<0>(block_shape), ck::wrapper::size<1>(shape) / ck::wrapper::size<1>(block_shape), ck::wrapper::size<2>(shape) / ck::wrapper::size<2>(block_shape)); - std::vector block_idxs(ck::wrapper::size(num_blocks)); - std::iota(block_idxs.begin(), block_idxs.end(), 0); + std::vector> block_idxs; + for(int i = 0; i < ck::wrapper::size<0>(grid_shape); i++) + { + for(int j = 0; j < ck::wrapper::size<1>(grid_shape); j++) + { + for(int k = 0; k < ck::wrapper::size<2>(grid_shape); k++) + { + block_idxs.emplace_back(i, j, k, 0); + } + } + } for(auto block_idx : block_idxs) { + constexpr ck::index_t projection_block_dim = ck::Number<2>{}; const auto packed_tile = ck::wrapper::make_local_tile(tensor, block_shape, block_idx, block_projection); const auto expected_tile_size = ck::wrapper::size(block_shape) / projection_block_dim; - auto expected_tile_first_val = (block_idx % ck::wrapper::size<2>(num_blocks)) * + auto expected_tile_first_val = ck::wrapper::size<2>(block_idx) * ck::wrapper::size<2>(block_shape) * ck::wrapper::size<2>(strides); - block_idx /= ck::wrapper::size<2>(num_blocks); - expected_tile_first_val += (block_idx % ck::wrapper::size<1>(num_blocks)) * + expected_tile_first_val += ck::wrapper::size<1>(block_idx) * ck::wrapper::size<1>(block_shape) * ck::wrapper::size<1>(strides); - block_idx /= ck::wrapper::size<1>(num_blocks); - expected_tile_first_val += (block_idx % ck::wrapper::size<0>(num_blocks)) * + expected_tile_first_val += ck::wrapper::size<0>(block_idx) * ck::wrapper::size<0>(block_shape) * ck::wrapper::size<0>(strides); diff --git a/test/wrapper/test_tensor.cpp b/test/wrapper/test_wrapper_tensor.cpp similarity index 100% rename from test/wrapper/test_tensor.cpp rename to test/wrapper/test_wrapper_tensor.cpp From b9ab9f4b4fd3e5787216e291d6ffb485465c38d1 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 15 Feb 2024 15:46:01 -0800 Subject: [PATCH 07/36] upgrade the ccache version and update links (#1169) --- Dockerfile | 11 +++++++---- dev-requirements.txt | 2 +- docs/Contributors_Guide.rst | 8 ++++---- docs/dockerhub.rst | 4 ++-- docs/tutorial_hello_world.rst | 4 ++-- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/Dockerfile b/Dockerfile index 48ee97eec2..38f234943c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -44,7 +44,6 @@ ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ cmake \ - ccache \ git \ hip-rocclr \ iputils-ping \ @@ -74,6 +73,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +#Install latest ccache +RUN git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install + #Install ninja build tracing tools RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip RUN gunzip /usr/local/bin/ninja.gz @@ -111,7 +114,7 @@ ENV LANG=C.UTF-8 RUN groupadd -f render # Install the new rocm-cmake version -RUN git clone -b master https://github.com/RadeonOpenCompute/rocm-cmake.git && \ +RUN git clone -b master https://github.com/ROCm/rocm-cmake.git && \ cd rocm-cmake && mkdir build && cd build && \ cmake .. && cmake --build . && cmake --build . --target install @@ -123,7 +126,7 @@ RUN sh -c "echo compiler version = '$compiler_version'" RUN sh -c "echo compiler commit = '$compiler_commit'" RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" = "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ @@ -131,7 +134,7 @@ RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd fi RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline-open" ] ) && [ "$compiler_commit" != "" ]; then \ - git clone -b "$compiler_version" https://github.com/RadeonOpenCompute/llvm-project.git && \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ make -j 8 ; \ diff --git a/dev-requirements.txt b/dev-requirements.txt index d5d91f8c27..ca883c19e1 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,3 +1,3 @@ ROCm/rocm-recipes -RadeonOpenCompute/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build +ROCm/rocm-cmake@04f694df2a8dc9d7e35fa4dee4ba5fa407ec04f8 --build danmar/cppcheck@2.9 diff --git a/docs/Contributors_Guide.rst b/docs/Contributors_Guide.rst index b91984357a..3788ba609c 100644 --- a/docs/Contributors_Guide.rst +++ b/docs/Contributors_Guide.rst @@ -17,7 +17,7 @@ Getting started `Composable Kernel User Guide `_. It provides insight into the core concepts, environment configuration, and steps to obtain or build the library. You can also find some of this information in the - `README file `_ + `README file `_ on the project's GitHub page. #. **Additional reading:** The blog post `AMD Composable Kernel library: efficient fused kernels for AI apps with just a few lines of code `_ provides a deeper understanding of the CK library and showcases its performance capabilities. `_ @@ -33,7 +33,7 @@ You can make an impact by reporting issues or proposing code enhancements throug Reporting issues ---------------- -Use `Github issues `_ +Use `Github issues `_ to track public bugs and enhancement requests. If you encounter an issue with the library, please check if the problem has already been @@ -68,7 +68,7 @@ Creating Pull Requests ---------------------- You can submit `Pull Requests (PR) on GitHub -`_. +`_. All contributors are required to develop their changes on a separate branch and then create a pull request to merge their changes into the `develop` branch, which is the default @@ -89,7 +89,7 @@ When submitting a Pull Request you should: the project's root directory. We leverage `pre-commit` to run `clang-format` automatically. We highly recommend contributors utilize this method to maintain consistent code formatting. Instructions on setting up `pre-commit` can be found in the project's - `README file `_ + `README file `_ * Link your PR to any related issues: diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index fb89bef72b..21121f1b82 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -38,7 +38,7 @@ The docker images have everything you need for running CK including: * `ROCm `_ * `CMake `_ -* `Compiler `_ +* `Compiler `_ * `Composable Kernel library `_ Running the docker container @@ -97,5 +97,5 @@ Editing the docker image ======================= If you want to customize the docker image, edit the -`Dockerfile `_ +`Dockerfile `_ from the GitHub repository to suit your needs. diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial_hello_world.rst index d89331e579..c31460785b 100644 --- a/docs/tutorial_hello_world.rst +++ b/docs/tutorial_hello_world.rst @@ -32,7 +32,7 @@ CK library acceleration features are based on: If you need more technical details and benchmarking results read the following `blog post `_. -To download the library visit the `composable_kernel repository `_. +To download the library visit the `composable_kernel repository `_. Hardware targets ================ @@ -58,7 +58,7 @@ This tutorial is based on the use of docker images as explained in :ref:`docker- .. note:: - You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. + You can also `install ROCm `_ on your system, clone the `Composable Kernel repository `_ on GitHub, and use that to build and run the examples using the commands described below. Both the docker container and GitHub repository include the Composable Kernel library. Navigate to the library:: From abac8b07ddd5d8918f9f0d7b80ce16308e990c9d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 16 Feb 2024 07:48:52 -0800 Subject: [PATCH 08/36] Bump rocm-docs-core from 0.34.0 to 0.34.2 in /docs/sphinx (#1170) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.34.0 to 0.34.2. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.34.0...v0.34.2) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 65341af8d6..ae2cbe44ab 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.34.0 +rocm-docs-core==0.34.2 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 74016ea8a2..84f232fa2d 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.34.0 +rocm-docs-core==0.34.2 # via -r requirements.in six==1.16.0 # via From 66736edb95fb9e0250a2fd23ce75001c968caa73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 20 Feb 2024 18:56:54 +0100 Subject: [PATCH 09/36] Extend permute scale support up to 6D (#1168) * Extend permute scale support up to 6D * Fixes * Fixes * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa * Update profiler/README.md Co-authored-by: Lisa --------- Co-authored-by: Lisa --- .../impl/device_elementwise_scale_impl.hpp | 15 +- .../gpu/permute_scale.hpp | 195 ++++++++- .../device_permute_scale_instances.hpp} | 98 ++--- .../gpu/permute_scale/CMakeLists.txt | 7 +- .../device_permute_scale_1d_instances.cpp | 29 ++ .../device_permute_scale_2d_instances.cpp | 29 ++ .../device_permute_scale_3d_instances.cpp | 29 ++ .../device_permute_scale_4d_instances.cpp | 29 ++ .../device_permute_scale_5d_instances.cpp | 29 ++ .../device_permute_scale_6d_instances.cpp | 29 ++ profiler/README.md | 50 ++- .../profiler/profile_permute_scale_impl.hpp | 400 ++++++++---------- profiler/src/CMakeLists.txt | 2 + profiler/src/profile_permute_scale.cpp | 170 ++++++++ test/permute_scale/test_permute_scale.cpp | 86 +++- 15 files changed, 898 insertions(+), 299 deletions(-) rename library/{src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp => include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp} (50%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp rename test/permute_scale/test_permute_scale_impl.hpp => profiler/include/profiler/profile_permute_scale_impl.hpp (53%) create mode 100644 profiler/src/profile_permute_scale.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp index 5e0f5e288e..33d70b0b88 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -322,6 +322,19 @@ struct DeviceElementwiseImpl : public DeviceElementwise(); }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseNormalizationImpl<"; + str << NumDim << ", "; + str << MPerThread << ">"; + // clang-format on + + return str.str(); + } }; // namespace device } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp index 6ea1244c57..4b3f40e214 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,7 +17,32 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_f16_instances( +#ifdef CK_ENABLE_FP16 +void add_device_permute_scale_1d_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 1>>>&); + +void add_device_permute_scale_2d_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 2>>>&); + +void add_device_permute_scale_3d_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 3>>>&); + +void add_device_permute_scale_4d_f16_instances( std::vector, ck::Tuple, PassThrough, @@ -25,7 +50,50 @@ void add_device_permute_scale_f16_instances( Scale, 4>>>&); -void add_device_permute_scale_f32_instances( +void add_device_permute_scale_5d_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 5>>>&); + +void add_device_permute_scale_6d_f16_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 6>>>&); + +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_permute_scale_1d_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 1>>>&); + +void add_device_permute_scale_2d_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 2>>>&); + +void add_device_permute_scale_3d_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 3>>>&); + +void add_device_permute_scale_4d_f32_instances( std::vector, ck::Tuple, PassThrough, @@ -33,6 +101,23 @@ void add_device_permute_scale_f32_instances( Scale, 4>>>&); +void add_device_permute_scale_5d_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 5>>>&); + +void add_device_permute_scale_6d_f32_instances( + std::vector, + ck::Tuple, + PassThrough, + element_wise::UnarySquare, + Scale, + 6>>>&); +#endif + template > op_ptrs; - if constexpr(is_same_v> && - is_same_v>) + if constexpr(NumDim == 1) { - add_device_permute_scale_f32_instances(op_ptrs); +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_1d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_1d_f16_instances(op_ptrs); + } +#endif } - else if constexpr(is_same_v> && - is_same_v>) + else if constexpr(NumDim == 2) { - add_device_permute_scale_f16_instances(op_ptrs); +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_2d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_2d_f16_instances(op_ptrs); + } +#endif + } + else if constexpr(NumDim == 3) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_3d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_3d_f16_instances(op_ptrs); + } +#endif + } + else if constexpr(NumDim == 4) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_4d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_4d_f16_instances(op_ptrs); + } +#endif + } + else if constexpr(NumDim == 5) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_5d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_5d_f16_instances(op_ptrs); + } +#endif + } + else if constexpr(NumDim == 6) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_6d_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v> && + is_same_v>) + { + add_device_permute_scale_6d_f16_instances(op_ptrs); + } +#endif } return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp similarity index 50% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp rename to library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp index fbbedd52e8..a672ab22df 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.cpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp @@ -1,56 +1,42 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" -#include "ck/utility/data_type.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Pass = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; - -// clang-format off -using device_permute_scale_f16_instances = - std::tuple < - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>> - >; - -using device_permute_scale_f32_instances = std::tuple< - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 8, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 4, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, 4, 2, ck::Sequence<1>, ck::Sequence<1>> - >; -// clang-format on - -void add_device_permute_scale_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances{}); -} - -void add_device_permute_scale_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f32_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Pass = ck::tensor_operation::element_wise::PassThrough; +using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; +using Scale = ck::tensor_operation::element_wise::Scale; + +// clang-format off +template +using device_permute_scale_f16_instances = + std::tuple < + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + >; + +template +using device_permute_scale_f32_instances = std::tuple< + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + >; +// clang-format on + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt index 8b45c1ab07..86652c0bf6 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt @@ -1,2 +1,7 @@ add_instance_library(device_permute_scale_instance - device_permute_scale_instances.cpp) + device_permute_scale_1d_instances.cpp + device_permute_scale_2d_instances.cpp + device_permute_scale_3d_instances.cpp + device_permute_scale_4d_instances.cpp + device_permute_scale_5d_instances.cpp + device_permute_scale_6d_instances.cpp) diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp new file mode 100644 index 0000000000..77d3baf4d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_1d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<1>{}); +} + +void add_device_permute_scale_1d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp new file mode 100644 index 0000000000..399b6b0490 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_2d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<2>{}); +} + +void add_device_permute_scale_2d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<2>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp new file mode 100644 index 0000000000..29f2f9fd5c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_3d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{}); +} + +void add_device_permute_scale_3d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<3>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp new file mode 100644 index 0000000000..3ad1d59e66 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_4d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<4>{}); +} + +void add_device_permute_scale_4d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<4>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp new file mode 100644 index 0000000000..6a4383bc95 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_5d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<5>{}); +} + +void add_device_permute_scale_5d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<5>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp new file mode 100644 index 0000000000..71e5867e9a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_permute_scale_6d_f16_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f16_instances<6>{}); +} + +void add_device_permute_scale_6d_f32_instances( + std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<6>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/README.md b/profiler/README.md index e53f22754a..f26c90d0b3 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -37,9 +37,9 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s ################ op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads ./bin/ckProfiler conv2d_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 ``` - Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` + +```bash in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} @@ -104,6 +104,7 @@ arg.b_grid_desc_k0_n0_n1_k1_{2048, 4096, 2} arg.e_grid_desc_m_n_{ 4096, 4096} .... Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s +``` ## Profile grouped convolution backward data kernels ```bash # arg1: tensor operation (grouped_conv_bwd_data: Grouped Convolution Backward Data) @@ -129,10 +130,11 @@ Best Perf: 58.0306 ms, 37.8942 TFlops, 27.7545 GB/s ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx ./bin/ckProfiler grouped_conv_bwd_data 1 0 1 1 0 1 2 32 4 192 192 3 3 28 28 1 1 1 1 1 1 1 1 - ``` +``` Result (MI100, FP16, GNHWC_GKYXC_GNHWK) -``` + +```bash out: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} wei: dim 5, lengths {32, 192, 192, 3, 3}, strides {331776, 1728, 1, 576, 192} in: dim 5, lengths {32, 4, 192, 28, 28}, strides {602112, 150528, 1, 5376, 192} @@ -173,10 +175,11 @@ GB/s: 127.947 ################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK ./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1 - ``` +``` Result (MI100, FP16, GNHWC_GKYXC_GNHWK) -``` + +```bash input: dim 5, lengths {32, 512, 1024, 28, 28}, strides {411041792, 802816, 1, 28672, 1024} weight: dim 5, lengths {32, 512, 1024, 3, 3}, strides {4718592, 9216, 1, 3072, 1024} output: dim 5, lengths {32, 512, 512, 26, 26}, strides {177209344, 346112, 1, 13312, 512} @@ -190,8 +193,9 @@ GB/s: 69.2301 Note: This kernel use atomic add, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. ## Profile image to column/column to image kernels + ```bash -# arg1: tensor operation (" OP_NAME ": " OP_DESC ") +# arg1: tensor operation ( conv_tensor_rearrange : Conv Tensor Rearrange ) # arg2: data type (0: Input fp32, Weight fp32, Output fp32 # 1: Input fp16, Weight fp16, Output fp16 # 2: Input bf16, Weight bf16, Output bf16 @@ -216,10 +220,11 @@ Note: This kernel use atomic add, this will cause output buffer to be accumulate ################ op datatype layout verify init log time opType Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx ./bin/ckProfiler conv_tensor_rearrange 0 0 0 1 0 1 0 2 1 256 1 512 3 3 28 28 1 1 1 1 0 0 0 0 - ``` +``` Result (MI210, FP32, NHWC) -``` + +```bash input: dim 5, lengths {1, 256, 512, 28, 28}, strides {102760448, 401408, 1, 14336, 512} output: dim 2, lengths {173056, 4608}, strides {4608, 1} .... @@ -229,3 +234,30 @@ avg_time: 3.12326 GB/s: 2042.59 ``` Note: Column to image kernel adds to the output memory, this will cause output buffer to be accumulated multiple times, causing verification failure. To work around it, do not use CK's own timer and do verification at the same time. + +## Profile Permute scale kernels + +```bash +# arg1: tensor operation ( permute_scale : Permute Scale ) +# arg2: data type (0: Input fp32, Output fp32 +# 1: Input fp16, Output fp16 +# arg4: verification (0: no, 1: yes) +# arg5: initialization (0: no init, 1: integer value, 2: decimal value) +# arg6: print tensor value (0: no; 1: yes) +# arg7: time kernel (0: no, 1: yes) +# from arg8: tensor lengths +# input strides +# output strides + +################ op datatype verify init log time dim0 dim1 dim2 in_stride0 in_stride1 in_stride2 out_stride0 out_stride1 out_stride2 +./bin/ckProfiler permute_scale 0 1 1 0 1 64 64 64 4096 64 1 1 64 4096 +``` + +Result (MI100, FP32) + +```bash +A: dim 3, lengths {64, 64, 64}, strides {4096, 64, 1} +B: dim 3, lengths {64, 64, 64}, strides {1, 64, 4096} +.... +Best perf = 0.0146878 ms, 142.782 GB/s, DeviceElementwiseNormalizationImpl<3, 2> +``` diff --git a/test/permute_scale/test_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp similarity index 53% rename from test/permute_scale/test_permute_scale_impl.hpp rename to profiler/include/profiler/profile_permute_scale_impl.hpp index 3837e7ef5a..5bc7c029f4 100644 --- a/test/permute_scale/test_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -1,212 +1,188 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" - -#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/utility/literals.hpp" - -namespace ck { -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) -{ - std::size_t N = A_nchw.mDesc.GetLengths()[0]; - std::size_t C = A_nchw.mDesc.GetLengths()[1]; - std::size_t H = A_nchw.mDesc.GetLengths()[2]; - std::size_t W = A_nchw.mDesc.GetLengths()[3]; - for(std::size_t w = 0; w < W; ++w) - for(std::size_t h = 0; h < H; ++h) - for(std::size_t c = 0; c < C; ++c) - for(std::size_t n = 0; n < N; ++n) - { - using tmp_type = ck::remove_reference_t; - tmp_type tmp_val = 0; - auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor_b(tmp_val, a_val); - functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], - scale * tmp_val); - } -} - -template -bool test_permute_scale_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - std::vector lengths) -{ - bool pass = true; - - using ElementOp = ck::tensor_operation::element_wise::PassThrough; - using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; - using Scale = ck::tensor_operation::element_wise::Scale; - float scale = 2.f; - - index_t N = lengths[0]; - index_t C = lengths[1]; - index_t H = lengths[2]; - index_t W = lengths[3]; - - std::vector nchw = {N, C, H, W}; - std::vector nhwc = {N, H, W, C}; - Tensor a(nchw); - Tensor b(nhwc); - Tensor host_b(nhwc); - - std::array ab_lengths; - - std::array a_strides = {1, - static_cast(nchw[0]), - static_cast(nchw[0] * nchw[1]), - static_cast(nchw[0] * nchw[1] * nchw[2])}; - - std::array b_strides = {1, - static_cast(nhwc[0] * nhwc[1] * nhwc[2]), - static_cast(nhwc[0]), - static_cast(nhwc[0] * nhwc[1])}; - ck::ranges::copy(nchw, ab_lengths.begin()); - - std::cout << "A: " << a.mDesc << std::endl; - std::cout << "B: " << b.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: a.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; - default: // a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0} - std::mt19937 gen(11939); - std::uniform_int_distribution dis(0, 1); - auto i = 0; - for(std::size_t w = 0; w < a.mDesc.GetLengths()[3]; ++w) - for(std::size_t h = 0; h < a.mDesc.GetLengths()[2]; ++h) - for(std::size_t c = 0; c < a.mDesc.GetLengths()[1]; ++c) - for(std::size_t n = 0; n < a.mDesc.GetLengths()[0]; ++n) - { - a.mData[(n * nchw[1] * nchw[2] * nchw[3]) + (c * nchw[2] * nchw[3]) + - (h * nchw[3]) + w] = i; - i = dis(gen); - } - } - - DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); - - a_device_buf.ToDevice(a.mData.data()); - - std::array input = {a_device_buf.GetDeviceBuffer()}; - std::array output = {b_device_buf.GetDeviceBuffer()}; - using DeviceOp = ck::tensor_operation::device::DeviceElementwise, - ck::Tuple, - ElementOp, - UnaryOp, - Scale, - NumDim>; - - // get device op instances - const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< - DeviceOp>::GetInstances(); - - std::cout << "found " << op_ptrs.size() << " instances" << std::endl; - - std::string best_instance_name; - float best_ave_time = std::numeric_limits::max(); - float best_gb_per_sec = 0; - float best_tflops = 0; - - if(do_verification) - { - host_elementwise4D(host_b, a, ElementOp{}, UnaryOp{}, scale); - } - - for(auto& op_ptr : op_ptrs) - { - auto argument_ptr = op_ptr->MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - ElementOp{}, - UnaryOp{}, - Scale{scale}); - - auto invoker_ptr = op_ptr->MakeInvokerPointer(); - - if(op_ptr->IsSupportedArgument(argument_ptr.get())) - { - b_device_buf.SetZero(); - - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); - - if(do_verification) - { - b_device_buf.FromDevice(b.mData.data()); - - pass &= ck::utils::check_err( - b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); - - if(do_log) - { - LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; - LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; - } - } - - std::string op_name = op_ptr->GetTypeString(); - - float ave_time = - invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * nchw[0] * nchw[1] * nchw[2] * nchw[3]; - - std::size_t num_btype = sizeof(ADataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]) + - sizeof(BDataType) * (nchw[0] * nchw[1] * nchw[2] * nchw[3]); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; - - if(tflops > best_tflops) - { - best_instance_name = op_name; - best_tflops = tflops; - best_ave_time = ave_time; - best_gb_per_sec = gb_per_sec; - } - } - else - { - std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; - } - } - if(time_kernel) - { - LogRange(std::cout << "length = ", lengths, ",") << ", "; - std::cout << "best perf = " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " - << best_instance_name << std::endl; - } - - return true; -} - -} // namespace ck +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" + +#include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" + +namespace ck { +template +void reference_permute_scale(HostTensorB& b_tensor, + const HostTensorA& a_tensor, + AElementOp a_tensor_op, + BElementOp b_tensor_op, + ScaleElementOp scale_op) +{ + b_tensor.ForEach([&](auto& self, auto idx) { + auto tmp_val = a_tensor(idx); + b_tensor_op(tmp_val, tmp_val); + scale_op(tmp_val, tmp_val); + a_tensor_op(self(idx), tmp_val); + }); +} + +namespace profiler { + +template +bool profile_permute_scale_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + std::vector lengths_vector, + std::vector input_strides_vector, + std::vector output_strides_vector) +{ + bool pass = true; + bool instance_found = false; + + using ElementOp = ck::tensor_operation::element_wise::PassThrough; + using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; + using Scale = ck::tensor_operation::element_wise::Scale; + float scale = 2.f; + + Tensor a(lengths_vector, input_strides_vector); + Tensor b(lengths_vector, output_strides_vector); + Tensor host_b(lengths_vector, output_strides_vector); + + std::cout << "A: " << a.mDesc << std::endl; + std::cout << "B: " << b.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: a.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; + default: a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a.mData.data()); + + std::array input = {a_device_buf.GetDeviceBuffer()}; + std::array output = {b_device_buf.GetDeviceBuffer()}; + using DeviceOp = ck::tensor_operation::device::DeviceElementwise, + ck::Tuple, + ElementOp, + UnaryOp, + Scale, + NumDim>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_instance_name; + float best_ave_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + if(do_verification) + { + reference_permute_scale(host_b, a, ElementOp{}, UnaryOp{}, Scale{scale}); + } + + auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; + std::array lengths{}; + std::array input_strides{}; + std::array output_strides{}; + copy(lengths_vector, lengths); + copy(input_strides_vector, input_strides); + copy(output_strides_vector, output_strides); + + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, + {input_strides}, + {output_strides}, + input, + output, + ElementOp{}, + UnaryOp{}, + Scale{scale}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + instance_found = true; + + b_device_buf.SetZero(); + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + + if(do_verification) + { + b_device_buf.FromDevice(b.mData.data()); + + pass &= ck::utils::check_err( + b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * a.mDesc.GetElementSpaceSize() / sizeof(ADataType); + + std::size_t num_btype = sizeof(ADataType) * a.mDesc.GetElementSpaceSize() + + sizeof(BDataType) * b.mDesc.GetElementSpaceSize(); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_instance_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + if(time_kernel) + { + std::cout << "Best perf = " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, " + << best_instance_name << std::endl; + } + + return pass && instance_found; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index c4b54d235f..f962d79900 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -32,6 +32,7 @@ set(PROFILER_SOURCES profile_grouped_conv_bwd_data.cpp profile_conv_tensor_rearrange.cpp profile_transpose.cpp + profile_permute_scale.cpp ) if(DL_KERNELS) @@ -99,6 +100,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_bwd_d target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_image_to_column_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance) +target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance) if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) diff --git a/profiler/src/profile_permute_scale.cpp b/profiler/src/profile_permute_scale.cpp new file mode 100644 index 0000000000..921b9b9a69 --- /dev/null +++ b/profiler/src/profile_permute_scale.cpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_permute_scale_impl.hpp" +#include "profiler_operation_registry.hpp" + +namespace { + +enum struct DataType +{ + F32_F32, // 0 + F16_F16 // 1 +}; + +#define OP_NAME "permute_scale" +#define OP_DESC "Permute Scale" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Output fp32\n" + << " 1: Input fp16, Output fp16\n" + << "arg4: verification (0: no, 1: yes)\n" + << "arg5: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0: no, 1: yes)\n" + << "from arg8: tensor lengths\n" + << " input strides\n" + << " output strides\n" << std::endl; + // clang-format on +} + +} // namespace + +int profile_permute_scale(int argc, char* argv[]) +{ + constexpr int control_argc = 7; + const int dims_argc = argc - control_argc; + // Number of lenghs, input strides and outputs strides must be equal + if(argc < control_argc && dims_argc % 3 != 0) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + const int num_dims = dims_argc / 3; + + std::vector lengths(num_dims); + std::vector input_strides(num_dims); + std::vector output_strides(num_dims); + + for(int i = 0; i < num_dims; i++) + { + lengths[i] = std::stoi(argv[control_argc + i]); + input_strides[i] = std::stoi(argv[control_argc + num_dims + i]); + output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]); + } + + using F32 = float; + using F16 = ck::half_t; + + constexpr auto I1 = ck::Number<1>{}; + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + constexpr auto I4 = ck::Number<4>{}; + constexpr auto I5 = ck::Number<5>{}; + constexpr auto I6 = ck::Number<6>{}; + + auto profile = [&](auto num_dim_tmp, auto in_type, auto out_type) { + constexpr ck::index_t NDim = num_dim_tmp.value; + + using InDataType = decltype(in_type); + using OutDataType = decltype(out_type); + + bool pass = + ck::profiler::profile_permute_scale_impl(do_verification, + init_method, + do_log, + time_kernel, + lengths, + input_strides, + output_strides); + + return pass ? 0 : 1; + }; + + if(num_dims == 1) + { + if(data_type == DataType::F32_F32) + { + return profile(I1, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I1, F16{}, F16{}); + } + } + else if(num_dims == 2) + { + if(data_type == DataType::F32_F32) + { + return profile(I2, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I2, F16{}, F16{}); + } + } + else if(num_dims == 3) + { + if(data_type == DataType::F32_F32) + { + return profile(I3, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I3, F16{}, F16{}); + } + } + else if(num_dims == 4) + { + if(data_type == DataType::F32_F32) + { + return profile(I4, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I4, F16{}, F16{}); + } + } + else if(num_dims == 5) + { + if(data_type == DataType::F32_F32) + { + return profile(I5, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I5, F16{}, F16{}); + } + } + else if(num_dims == 6) + { + if(data_type == DataType::F32_F32) + { + return profile(I6, F32{}, F32{}); + } + else if(data_type == DataType::F16_F16) + { + return profile(I6, F16{}, F16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_permute_scale); diff --git a/test/permute_scale/test_permute_scale.cpp b/test/permute_scale/test_permute_scale.cpp index 518d3fc87a..780f6d6edb 100644 --- a/test/permute_scale/test_permute_scale.cpp +++ b/test/permute_scale/test_permute_scale.cpp @@ -1,8 +1,8 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" -#include "test_permute_scale_impl.hpp" +#include "profiler/profile_permute_scale_impl.hpp" using F16 = ck::half_t; using F32 = float; @@ -15,15 +15,32 @@ class TestPermute : public ::testing::Test using ADataType = std::tuple_element_t<0, Tuple>; using BDataType = std::tuple_element_t<1, Tuple>; - void Run() + constexpr bool skip_case() { - std::vector> lengths = { - {4, 2, 1, 8}, {1, 1, 1, 1}, {16, 8, 32, 64}, {32, 64, 128, 128}}; - - for(auto length : lengths) +#ifndef CK_ENABLE_FP16 + if constexpr(ck::is_same_v || ck::is_same_v) { - bool success = - ck::test_permute_scale_impl(true, 2, false, false, length); + return true; + } +#endif +#ifndef CK_ENABLE_FP32 + if constexpr(ck::is_same_v || ck::is_same_v) + { + return true; + } +#endif + return false; + } + + template + void Run(std::vector lengths, + std::vector input_strides, + std::vector output_strides) + { + if(!skip_case()) + { + bool success = ck::profiler::profile_permute_scale_impl( + true, 2, false, false, lengths, input_strides, output_strides); EXPECT_TRUE(success); } } @@ -32,5 +49,52 @@ class TestPermute : public ::testing::Test using KernelTypes = ::testing::Types, std::tuple>; TYPED_TEST_SUITE(TestPermute, KernelTypes); -TYPED_TEST(TestPermute, Test_FP16) { this->Run(); } -TYPED_TEST(TestPermute, Test_FP32) { this->Run(); } +TYPED_TEST(TestPermute, Test1D) +{ + constexpr ck::index_t NumDims = 1; + this->template Run({8}, {1}, {2}); + this->template Run({8}, {2}, {1}); + this->template Run({1}, {1}, {1}); +} + +TYPED_TEST(TestPermute, Test2D) +{ + constexpr ck::index_t NumDims = 2; + this->template Run({8, 4}, {4, 1}, {1, 8}); + this->template Run({8, 4}, {1, 8}, {4, 1}); + this->template Run({1, 1}, {1, 1}, {1, 1}); +} + +TYPED_TEST(TestPermute, Test3D) +{ + constexpr ck::index_t NumDims = 3; + this->template Run({2, 4, 4}, {16, 4, 1}, {1, 2, 8}); + this->template Run({2, 4, 4}, {1, 2, 8}, {16, 4, 1}); + this->template Run({1, 1, 1}, {1, 1, 1}, {1, 1, 1}); +} + +TYPED_TEST(TestPermute, Test4D) +{ + constexpr ck::index_t NumDims = 4; + this->template Run({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32}); + this->template Run({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1}); + this->template Run({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}); +} + +TYPED_TEST(TestPermute, Test5D) +{ + constexpr ck::index_t NumDims = 5; + this->template Run({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128}); + this->template Run({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1}); + this->template Run({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}); +} + +TYPED_TEST(TestPermute, Test6D) +{ + constexpr ck::index_t NumDims = 6; + this->template Run( + {2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512}); + this->template Run( + {2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1}); + this->template Run({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}); +} From 32d4be3d090830b565bb460ad1d5ea27e58cf956 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 21 Feb 2024 10:35:35 +0100 Subject: [PATCH 10/36] Add support for mixed precision bf16&int8 grouped gemm (#1166) * add support for mixed precision bf16&int8 grouped gemm * fix gfx versions and add bf16 kbatch condition * added reviewers comments --- client_example/22_grouped_gemm/CMakeLists.txt | 3 + .../grouped_gemm_fixed_nk_bf16.cpp | 237 +++++++++++ .../impl/device_grouped_gemm_xdl_fixed_nk.hpp | 61 ++- .../gpu/grouped_gemm_fixed_nk.hpp | 49 ++- .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 4 +- ...ixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp | 73 ++++ ...ixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp | 76 ++++ .../profile_grouped_gemm_fixed_nk_impl.hpp | 370 ++++++++++++++++++ profiler/src/CMakeLists.txt | 2 + .../src/profile_grouped_gemm_fixed_nk.cpp | 303 ++++++++++++++ 10 files changed, 1159 insertions(+), 19 deletions(-) create mode 100644 client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp create mode 100644 profiler/src/profile_grouped_gemm_fixed_nk.cpp diff --git a/client_example/22_grouped_gemm/CMakeLists.txt b/client_example/22_grouped_gemm/CMakeLists.txt index 19c613381e..0c3cb956f0 100644 --- a/client_example/22_grouped_gemm/CMakeLists.txt +++ b/client_example/22_grouped_gemm/CMakeLists.txt @@ -6,3 +6,6 @@ target_link_libraries(client_grouped_gemm_fixed_nk_fp8 PRIVATE composable_kernel add_executable(client_grouped_gemm_fixed_nk_i8 grouped_gemm_fixed_nk_i8.cpp) target_link_libraries(client_grouped_gemm_fixed_nk_i8 PRIVATE composable_kernel::device_gemm_operations) + +add_executable(client_grouped_gemm_fixed_nk_bf16 grouped_gemm_fixed_nk_bf16.cpp) +target_link_libraries(client_grouped_gemm_fixed_nk_bf16 PRIVATE composable_kernel::device_gemm_operations) diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp new file mode 100644 index 0000000000..9e8eac536b --- /dev/null +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = I8; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int main() +{ + std::vector Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; + + int sum_of_m = 0; + + const int group_count = 16; + + for(int i = 0; i < group_count; ++i) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(128 + 128 * i); + Ks.push_back(128 + 64 * i); + + StrideAs.push_back(std::is_same::value ? Ks[i] : Ms[i]); + StrideBs.push_back(std::is_same::value ? Ns[i] : Ks[i]); + StrideEs.push_back(std::is_same::value ? Ns[i] : Ms[i]); + + sum_of_m += Ms[i]; + } + + auto f_matrix_space_size = + [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { + using Layout = decltype(layout); + + if constexpr(std::is_same::value) + { + return (nRow - 1) * stride + nCol; + } + else + { + return (nCol - 1) * stride + nRow; + } + }; + + std::vector a_dev_bufs, b_dev_bufs, e_dev_bufs; + + a_dev_bufs.reserve(group_count); + b_dev_bufs.reserve(group_count); + e_dev_bufs.reserve(group_count); + + std::vector p_e; + + p_e.reserve(group_count); + + std::vector gemm_descs; + + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + a_dev_bufs.emplace_back(sizeof(ADataType) * + f_matrix_space_size(Ms[i], Ks[i], StrideAs[i], ALayout{})); + b_dev_bufs.emplace_back(sizeof(BDataType) * + f_matrix_space_size(Ks[i], Ns[i], StrideBs[i], BLayout{})); + e_dev_bufs.emplace_back(sizeof(EDataType) * + f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); + + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}}); + + p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_dev_bufs[i].GetDeviceBuffer(), + b_dev_bufs[i].GetDeviceBuffer(), + {}, + e_dev_bufs[i].GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideEs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{}; + + std::string best_op_name; + bool found = false; + int best_op_id = -1; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + std::vector p_a = {}, p_b = {}; + std::vector> p_ds = {}; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + + auto argument_ptr = op_ptr->MakeArgumentPointer( + p_a, p_b, p_ds, p_e, gemm_descs, a_element_op, b_element_op, cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + SimpleDeviceMem grouped_gemm_kernel_args_dev( + op_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + SimpleDeviceMem grouped_gemm_workspace_dev(op_ptr->GetWorkSpaceSize(argument_ptr.get())); + + std::string op_name = op_ptr->GetTypeString(); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + op_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + op_ptr->SetWorkSpacePointer(argument_ptr.get(), + grouped_gemm_workspace_dev.GetDeviceBuffer()); + + op_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + op_ptr->SetKBatch(argument_ptr.get(), 1); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = 0, num_btype = 0; + for(std::size_t j = 0; j < gemm_descs.size(); ++j) + { + flop += std::size_t(2) * Ms[j] * Ns[j] * Ks[j]; + + num_btype += sizeof(ADataType) * Ms[j] * Ks[j] + sizeof(BDataType) * Ks[j] * Ns[j] + + sizeof(EDataType) * Ms[j] * Ns[j]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + found = true; + best_op_id = i; + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 56cc8fb752..d197c56ab8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -650,22 +650,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK 1) - { - if(has_main_k_block_loop) - { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); - } - else - { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); - } - } - else + // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced + // in IsSupportedArgument function + if constexpr(std::is_same::value) { if(has_main_k_block_loop) { @@ -678,6 +665,39 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK{}); } } + else + { + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + } return ave_time; } @@ -718,6 +738,13 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK::value) + { + supported = supported & (arg.k_batch_ == 1); + } + return supported; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index e8c368cb38..a90fe14603 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -97,6 +97,35 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( PassThrough, PassThrough>>>& instances); +// bf16_inputA i8_inputB +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) +void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances); +#endif + template && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif + return op_ptrs; } }; diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index 3b48954d22..ac22543bef 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -5,6 +5,8 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16 device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp) + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..a88d2d7628 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; +using DsLayout = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..4dfff0db7c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I8 = int8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; +using DsLayout = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_irregular_tile_instances = std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Xdl_Fixed_NK< Row, Col, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp new file mode 100644 index 0000000000..5d2b7e0d9b --- /dev/null +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_grouped_gemm_fixed_nk_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::size_t group_count = Ms.size(); + + if(!(group_count == Ns.size() && group_count == Ks.size() && group_count == StrideAs.size() && + group_count == StrideBs.size() && group_count == StrideCs.size())) + { + throw std::runtime_error("wrong! inconsistent M/N/Ks, StrideA/B/Cs size\n"); + } + + std::vector> a_m_k; + std::vector> b_k_n; + std::vector> c_m_n_host_results; + std::vector> c_m_n_device_results; + + for(std::size_t i = 0; i < group_count; i++) + { + a_m_k.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); + b_k_n.push_back( + Tensor(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{}))); + + c_m_n_device_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); + + c_m_n_host_results.push_back( + Tensor(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}))); +#if DEBUG_LOG + std::cout << "group: " << i << " a_m_k[" << i << "]:" << a_m_k[i].mDesc << ", b_k_n[" << i + << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i + << "]:" << c_m_n_device_results[i].mDesc << std::endl; +#endif // DEBUG_LOG + std::size_t num_thread = 1; + switch(init_method) + { + case 0: break; + case 1: + a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + using DeviceMemPtr = std::unique_ptr; + std::vector a_device_buf, b_device_buf, c_device_buf; + + a_device_buf.reserve(group_count); + b_device_buf.reserve(group_count); + c_device_buf.reserve(group_count); + + std::vector p_a, p_b; + std::vector p_c; + + p_a.reserve(group_count); + p_b.reserve(group_count); + p_c.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + std::vector> + grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(std::size_t i = 0; i < group_count; i++) + { + a_device_buf.emplace_back( + std::make_unique(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize())); + b_device_buf.emplace_back( + std::make_unique(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize())); + c_device_buf.emplace_back(std::make_unique( + sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize())); + + a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); + b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); + + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + + p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); + p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); + p_c.push_back(c_device_buf[i]->GetDeviceBuffer()); + + grouped_gemm_kernel_args_.push_back({a_device_buf[i]->GetDeviceBuffer(), + b_device_buf[i]->GetDeviceBuffer(), + {}, + c_device_buf[i]->GetDeviceBuffer(), + Ms[i], + Ns[i], + Ks[i], + StrideAs[i], + StrideBs[i], + {}, + StrideCs[i]}); + } + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedGemmFixedNK, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AElementOp, + BElementOp, + CElementOp>; + + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + if(op_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + auto p_ds = std::vector>{}; + + if(do_verification) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k[i], + b_k_n[i], + c_m_n_host_results[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + } + + // profile device GEMM instances + for(auto& gemm_ptr : op_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_c, + gemm_descs, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}, + ck::tensor_operation::element_wise::PassThrough{}); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + DeviceMem gemm_desc_workspace(gemm_ptr->GetWorkSpaceSize(argument_ptr.get())); + + DeviceMem grouped_gemm_kernel_args_dev( + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get())); + + hipGetErrorString(hipMemcpy(grouped_gemm_kernel_args_dev.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm_ptr->GetDeviceKernelArgSize(argument_ptr.get()), + hipMemcpyHostToDevice)); + + gemm_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer()); + + gemm_ptr->SetDeviceKernelArgs(argument_ptr.get(), + grouped_gemm_kernel_args_dev.GetDeviceBuffer()); + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::vector kbatch_list = {1, 2, 4, 8, 12, 16, 20, 24, 32, 48, 64}; + + if(kbatch > 0) + { + kbatch_list = {kbatch}; + } + + for(std::size_t j = 0; j < kbatch_list.size(); j++) + { + + auto kbatch_curr = kbatch_list[j]; + + gemm_ptr->SetKBatch(argument_ptr.get(), kbatch_curr); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + for(std::size_t i = 0; i < gemm_descs.size(); i++) + c_device_buf[i]->SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + bool instance_pass = true; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + + c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); + + if(std::is_same_v && kbatch_curr > 1) + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + 0.06); + } + else + { + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i]); + } + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k[i].mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_results[i].mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_results[i].mData, ",") + << std::endl; + } + } + + std::cout << "Instance: " << gemm_name << " verification " + << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; + + pass = pass && instance_pass; + } + + float ave_time = invoker_ptr->Run( + argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, n_warmup, n_iter}); + + if(time_kernel) + { + std::size_t flop = 0, num_btype = 0; + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + flop += std::size_t(2) * Ms[i] * Ns[i] * Ks[i]; + + num_btype += sizeof(ADataType) * Ms[i] * Ks[i] + + sizeof(BDataType) * Ks[i] * Ns[i] + + sizeof(CDataType) * Ms[i] * Ns[i]; + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << gemm_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + } + else + { + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + << std::endl; + } + } + } + + if(time_kernel) + { + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch + << std::endl; + } + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index f962d79900..11ae285167 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -52,6 +52,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) list(APPEND PROFILER_SOURCES profile_gemm_add_relu_add_layernorm.cpp) list(APPEND PROFILER_SOURCES profile_batched_gemm_add_relu_gemm_add.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm.cpp) + list(APPEND PROFILER_SOURCES profile_grouped_gemm_fixed_nk.cpp) list(APPEND PROFILER_SOURCES profile_grouped_gemm_fastgelu.cpp) endif() @@ -126,6 +127,7 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_gemm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_add_relu_gemm_add_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_instance) + target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fixed_nk_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) endif() diff --git a/profiler/src/profile_grouped_gemm_fixed_nk.cpp b/profiler/src/profile_grouped_gemm_fixed_nk.cpp new file mode 100644 index 0000000000..3d280c2f43 --- /dev/null +++ b/profiler/src/profile_grouped_gemm_fixed_nk.cpp @@ -0,0 +1,303 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 +}; + +enum struct GemmDataType +{ + BF16_I8_BF16, // 0 + F16_F16_F16, // 1 + F16_F8_F16, // 2 + F16_I8_F16, // 3 + +}; + +#define OP_NAME "grouped_gemm_fixed_nk" +#define OP_DESC "Grouped GEMM Fixed NK" + +namespace { + +std::vector argToIntArray(char* input) +{ + std::vector out; + + std::istringstream in(input); + + std::string item; + + while(std::getline(in, item, ',')) + { + out.push_back(std::stoi(item)); + } + + return out; +} + +int profile_grouped_gemm_fixed_nk(int argc, char* argv[]) +{ + if(argc < 14) + { + std::cout + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: bf16@int8; 1: fp16; 2: fp16@fp8; 3: fp16@int8)\n" + << "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n" + << " 1: A[m, k] * B[n, k] = C[m, n];\n" + << "arg4: verification (0: no; 1: yes)\n" + << "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n" + << "arg6: print tensor value (0: no; 1: yes)\n" + << "arg7: time kernel (0=n0, 1=yes)\n" + << "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 " + "64,64 64,64 128,128)\n" + << "arg15: kbatch value (default 1)\n" + << "optional:\n" + << "arg16: number of warm-up cycles (default 1)\n" + << "arg17: number of iterations (default 10)\n" + << std::endl; + + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const auto Ms = argToIntArray(argv[8]); + const auto Ns = argToIntArray(argv[9]); + const auto Ks = argToIntArray(argv[10]); + + const auto StrideAs = argToIntArray(argv[11]); + const auto StrideBs = argToIntArray(argv[12]); + const auto StrideCs = argToIntArray(argv[13]); + const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1; + + using F32 = float; + using F16 = ck::half_t; + using F8 = ck::f8_t; + using BF16 = ck::bhalf_t; + using I8 = int8_t; + + int n_warmup = 1; + int n_iter = 10; + if(argc == 17) + { + n_warmup = std::stoi(argv[16]); + n_iter = std::stoi(argv[17]); + } + +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) + if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::BF16_I8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } +#endif +#if defined(CK_ENABLE_FP16) + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } +#endif +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } +#endif +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_INT8) + else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + else if(data_type == GemmDataType::F16_I8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_grouped_gemm_fixed_nk_impl( + do_verification, + init_method, + do_log, + time_kernel, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + 1, + n_warmup, + n_iter); + } +#endif + else + { + throw std::runtime_error("wrong! this GEMM data_type & layout is not implemented"); + } + return 0; +} + +} // anonymous namespace + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_grouped_gemm_fixed_nk); From 2eb74a9c0c86c832a75d7ebf1c9e899142ffac7a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 21:39:16 -0800 Subject: [PATCH 11/36] Bump rocm-docs-core from 0.34.2 to 0.35.0 in /docs/sphinx (#1175) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.34.2 to 0.35.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.34.2...v0.35.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index ae2cbe44ab..1576e54537 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.34.2 +rocm-docs-core==0.35.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 84f232fa2d..a8cb087225 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.34.2 +rocm-docs-core==0.35.0 # via -r requirements.in six==1.16.0 # via From d909599729a43128231a390ed363f5040eb5c5db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Mon, 26 Feb 2024 14:56:06 +0100 Subject: [PATCH 12/36] Remove unnecessary comments (#1177) --- client_example/25_wrapper/wrapper_basic_gemm.cpp | 1 - client_example/25_wrapper/wrapper_optimized_gemm.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/client_example/25_wrapper/wrapper_basic_gemm.cpp b/client_example/25_wrapper/wrapper_basic_gemm.cpp index 1f1a4de751..59c5c243ce 100644 --- a/client_example/25_wrapper/wrapper_basic_gemm.cpp +++ b/client_example/25_wrapper/wrapper_basic_gemm.cpp @@ -213,4 +213,3 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } -// MI300X Perf: 0.471337 ms, 273.369 TFlops, 204.671 GB/s, diff --git a/client_example/25_wrapper/wrapper_optimized_gemm.cpp b/client_example/25_wrapper/wrapper_optimized_gemm.cpp index ddf01de612..b6294c2393 100644 --- a/client_example/25_wrapper/wrapper_optimized_gemm.cpp +++ b/client_example/25_wrapper/wrapper_optimized_gemm.cpp @@ -305,4 +305,3 @@ int main(int argc, char* argv[]) 3840, 4096, 4096, tile_shape, thread_layout); return 0; } -// MI300X Perf: 0.411552 ms, 313.081 TFlops, 234.403 GB/s, From d0c7b45150695c2dff205c4ddc9cce2a2e6a2950 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:31:05 -0800 Subject: [PATCH 13/36] Clip fp8 to +/-240 on all targets. (#1172) * clip fp8 to +/-240 on all targets * if inputs to fp8 conversion are +/-inf, they remain unaltered * increase tolerance for test_elementwise_layernorm to prevent false errors * change the input values for gemm examples to floats * reduce gemm example float input values to prevent errors * increase the tolerance for gemm examples --- example/01_gemm/common.hpp | 2 +- example/01_gemm/run_gemm_example.inc | 7 ++++--- include/ck/utility/type_convert.hpp | 18 ++++++++++-------- .../profile_elementwise_layernorm_impl.hpp | 2 +- 4 files changed, 16 insertions(+), 13 deletions(-) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 7fd15b2833..eb281af7bb 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -49,7 +49,7 @@ struct ProblemSizeStreamK final struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 7be2539d90..49743a9c43 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -69,8 +69,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; default: - ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); - ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); + ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); @@ -240,7 +240,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); + return ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1); #endif } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 6bbff98312..b989094c0e 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -107,11 +107,12 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); + float max_fp8 = 240.0f; + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) - float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); union { float fval; @@ -144,7 +145,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -156,7 +157,7 @@ inline __host__ __device__ f8_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_t f8_convert_sr(float x) { - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); #if defined(__gfx94__) union @@ -191,7 +192,7 @@ inline __host__ __device__ bf8_t f8_convert_sr(half_t x) constexpr bool negative_zero_nan = true; constexpr bool clip = true; constexpr f8_rounding_mode rm = f8_rounding_mode::stochastic; - constexpr int seed = 42; + constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); return utils:: cast_to_f8( @@ -207,9 +208,10 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { -#if defined(__gfx94__) float max_fp8 = 240.0f; - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); + if(!std::isinf(x)) + x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); +#if defined(__gfx94__) union { float fval; diff --git a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp index ae42919db6..220076465d 100644 --- a/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp +++ b/profiler/include/profiler/profile_elementwise_layernorm_impl.hpp @@ -233,7 +233,7 @@ bool profile_elementwise_layernorm_impl(int do_verification, y_dev.FromDevice(y.mData.data()); bool pass = - ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3); + ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 5e-3, 5e-3); if(do_log) { From a776978cbe45f8bbfe2b2ecd88c9f5308efdb169 Mon Sep 17 00:00:00 2001 From: amoskvic <158011354+amoskvic@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:39:03 -0700 Subject: [PATCH 14/36] Style improvement: improving type alias usage consistency in gemm-related client examples. Also copyright year update for all client examples. (#1180) Co-authored-by: Arseny Moskvichev --- client_example/01_gemm/gemm.cpp | 4 ++-- .../02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp | 4 ++-- .../gemm_add_add_fastgelu_generic.cpp | 4 ++-- .../02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp | 4 ++-- .../02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp | 4 ++-- client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp | 4 ++-- .../02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp | 4 ++-- .../03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp | 6 ++++-- .../gemm_add_relu_add_layernorm_welford.cpp | 4 ++-- client_example/04_contraction/contraction_bilinear_fp32.cpp | 2 +- client_example/04_contraction/contraction_bilinear_fp64.cpp | 2 +- .../04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp | 2 +- client_example/04_contraction/contraction_scale_fp32.cpp | 2 +- client_example/04_contraction/contraction_scale_fp64.cpp | 2 +- client_example/05_layernorm/layernorm2d_bwd_data.cpp | 2 +- client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp | 2 +- client_example/05_layernorm/layernorm2d_fwd.cpp | 2 +- client_example/05_layernorm/layernorm4d_fwd.cpp | 2 +- client_example/06_softmax/softmax4d.cpp | 2 +- client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp | 2 +- client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp | 2 +- client_example/08_fused_attention/fused_attention.cpp | 2 +- client_example/08_fused_attention/fused_attention_bias.cpp | 2 +- .../conv2d_fwd_bias_relu_perchannel_quantization.cpp | 2 +- .../conv2d_fwd_bias_relu_perlayer_quantization.cpp | 2 +- .../conv2d_fwd_bias_tanh_perchannel_quantization.cpp | 2 +- .../conv2d_fwd_bias_tanh_perlayer_quantization.cpp | 2 +- .../09_quantization/conv2d_fwd_perchannel_quantization.cpp | 2 +- .../09_quantization/conv2d_fwd_perlayer_quantization.cpp | 2 +- client_example/09_quantization/gemm_quantization.cpp | 2 +- .../10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp | 2 +- .../10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp | 2 +- .../grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp | 2 +- .../grouped_conv1d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv2d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp32.cpp | 2 +- .../elementwise_layernorm2d.cpp | 2 +- client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp | 2 +- client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp | 2 +- client_example/13_batchnorm/batchnorm_infer_nhwc.cpp | 2 +- client_example/14_instance_id/batchnorm_fwd_instance_id.cpp | 2 +- client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp | 2 +- client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp | 2 +- client_example/15_gemm_add_multiply/gemm_add_multiply.cpp | 2 +- client_example/15_reduce/reduce_nhwc_c.cpp | 2 +- client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp | 2 +- client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp | 2 +- client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp | 2 +- .../17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp | 2 +- client_example/18_groupnorm/groupnorm_bwd_data.cpp | 2 +- client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp | 2 +- client_example/18_groupnorm/groupnorm_swish_fwd.cpp | 2 +- client_example/19_pool/avg_pool3d_bwd.cpp | 2 +- client_example/19_pool/avg_pool3d_fwd.cpp | 2 +- client_example/19_pool/max_pool2d_bwd.cpp | 2 +- client_example/19_pool/max_pool2d_fwd.cpp | 2 +- client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp | 4 ++-- .../grouped_gemm_fixed_nk_bias_fp16.cpp | 4 ++-- .../22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp | 2 +- .../22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp | 4 ++-- .../22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp | 4 ++-- client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp | 4 ++-- client_example/22_im2col_col2im/image_to_column.cpp | 2 +- .../23_elementwise_transpose/elementwise_transpose_3d.cpp | 2 +- 66 files changed, 82 insertions(+), 80 deletions(-) diff --git a/client_example/01_gemm/gemm.cpp b/client_example/01_gemm/gemm.cpp index 11f9222873..e63cda6162 100644 --- a/client_example/01_gemm/gemm.cpp +++ b/client_example/01_gemm/gemm.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -83,7 +83,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp index e845c120d8..5809681661 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -92,7 +92,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp index 2ed942f0ad..3cc4313aab 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_add_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -93,7 +93,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp index e77b67c905..1fd80d10c7 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp index 644b428fc9..e54bcfd989 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_add_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -89,7 +89,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp index 7648da9cac..47fd58f691 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -84,7 +84,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp index 482e93b421..f43554f2bd 100644 --- a/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp +++ b/client_example/02_gemm_add_add_fastgelu/gemm_fastgelu_generic.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -85,7 +85,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp index 58c91f903b..cbadd9cf76 100644 --- a/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_add_layernorm_naive.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -17,6 +17,8 @@ using F16 = ck::half_t; using F32 = float; +using Row = ck::tensor_layout::gemm::RowMajor; + using ADataType = F16; using BDataType = F16; using BiasDataType = F32; @@ -191,7 +193,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp index 93f8847c62..7d5ef5f9bf 100644 --- a/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp +++ b/client_example/03_gemm_layernorm/gemm_add_relu_add_layernorm_welford.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -78,7 +78,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/04_contraction/contraction_bilinear_fp32.cpp b/client_example/04_contraction/contraction_bilinear_fp32.cpp index 89f834b982..f1881e60a0 100644 --- a/client_example/04_contraction/contraction_bilinear_fp32.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_bilinear_fp64.cpp b/client_example/04_contraction/contraction_bilinear_fp64.cpp index 1aa3ba7de5..8b499eee21 100644 --- a/client_example/04_contraction/contraction_bilinear_fp64.cpp +++ b/client_example/04_contraction/contraction_bilinear_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp index f8ea2258c2..a5ef40a2dc 100644 --- a/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp +++ b/client_example/04_contraction/contraction_g1m2n3k1_add_xdl_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp32.cpp b/client_example/04_contraction/contraction_scale_fp32.cpp index ba7b0633c3..5c06d31488 100644 --- a/client_example/04_contraction/contraction_scale_fp32.cpp +++ b/client_example/04_contraction/contraction_scale_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/04_contraction/contraction_scale_fp64.cpp b/client_example/04_contraction/contraction_scale_fp64.cpp index 24e52eb5aa..14fb8741e7 100644 --- a/client_example/04_contraction/contraction_scale_fp64.cpp +++ b/client_example/04_contraction/contraction_scale_fp64.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm2d_bwd_data.cpp b/client_example/05_layernorm/layernorm2d_bwd_data.cpp index 9f26cb6840..ec02cb2c4e 100644 --- a/client_example/05_layernorm/layernorm2d_bwd_data.cpp +++ b/client_example/05_layernorm/layernorm2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp index 98b394add6..1d1ebefd5b 100644 --- a/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp +++ b/client_example/05_layernorm/layernorm2d_bwd_gamma_beta.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm2d_fwd.cpp b/client_example/05_layernorm/layernorm2d_fwd.cpp index 420225b613..22599f43ca 100644 --- a/client_example/05_layernorm/layernorm2d_fwd.cpp +++ b/client_example/05_layernorm/layernorm2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/05_layernorm/layernorm4d_fwd.cpp b/client_example/05_layernorm/layernorm4d_fwd.cpp index fa408dc751..c80fd31b6e 100644 --- a/client_example/05_layernorm/layernorm4d_fwd.cpp +++ b/client_example/05_layernorm/layernorm4d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/06_softmax/softmax4d.cpp b/client_example/06_softmax/softmax4d.cpp index a62af76635..eaddbf98ee 100644 --- a/client_example/06_softmax/softmax4d.cpp +++ b/client_example/06_softmax/softmax4d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp index 4d743a66f0..4983ac33c3 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index c5e51ad993..9383350629 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention.cpp b/client_example/08_fused_attention/fused_attention.cpp index df6bc11a70..339d92e756 100644 --- a/client_example/08_fused_attention/fused_attention.cpp +++ b/client_example/08_fused_attention/fused_attention.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/08_fused_attention/fused_attention_bias.cpp b/client_example/08_fused_attention/fused_attention_bias.cpp index 6c9f3bc8f6..a1200a9db4 100644 --- a/client_example/08_fused_attention/fused_attention_bias.cpp +++ b/client_example/08_fused_attention/fused_attention_bias.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index 78db4f8aa5..08919401cd 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index 4121e41af7..1d502ba4a2 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp index ea5f1dbd5b..5b9c9d3708 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 5b40298d6c..7c40aa4e60 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp index 0b78bbf272..3777cd5e1b 100644 --- a/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perchannel_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp index 7315f2bb55..1fbb1ddea4 100644 --- a/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_perlayer_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/09_quantization/gemm_quantization.cpp b/client_example/09_quantization/gemm_quantization.cpp index b14e68fa08..d2fadd8d91 100644 --- a/client_example/09_quantization/gemm_quantization.cpp +++ b/client_example/09_quantization/gemm_quantization.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index 1b2e8abc20..ae5f1b6f6e 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index d2f2ff41bc..93709a7901 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index 2330228d1d..a62a1d911b 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp index e6d427faf4..a51aab483e 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 4201ea61b4..705ad21ae8 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index 3ae46bcd55..5ed3896e7a 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp index 098b7cd868..868e0e2903 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index 2eb869f392..d5f1fc331b 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 8326f0758c..69d7c8936c 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp index 1ed36e0f50..4f6985a514 100644 --- a/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_bwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp index f9af011c84..9fa82523be 100644 --- a/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp index 5e6627ce14..6393cf3e65 100644 --- a/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp +++ b/client_example/13_batchnorm/batchnorm_infer_nhwc.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp index d45782d8e0..2a565738a7 100644 --- a/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp +++ b/client_example/14_instance_id/batchnorm_fwd_instance_id.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp index 5210567241..29dbc97f40 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp index 441bdfe7be..b53e892fdc 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp index cde4713b23..a8c2ae1214 100644 --- a/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp +++ b/client_example/15_gemm_add_multiply/gemm_add_multiply.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index b45b72f0de..e2b1fbcb54 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp index d4455df628..10033822dd 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp index 1651ec2f39..22ba25efb9 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16_comp_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 7e8c98b603..a739f9d05b 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp index 7ba3224fc3..6a745e1ab0 100644 --- a/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp +++ b/client_example/17_grouped_gemm_fastgelu/grouped_gemm_fastgelu.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/groupnorm_bwd_data.cpp b/client_example/18_groupnorm/groupnorm_bwd_data.cpp index 01ca21ba57..bcfa5f7dc6 100644 --- a/client_example/18_groupnorm/groupnorm_bwd_data.cpp +++ b/client_example/18_groupnorm/groupnorm_bwd_data.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp index c2fbe285df..06ab194a8e 100644 --- a/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp +++ b/client_example/18_groupnorm/groupnorm_bwd_gamma_beta.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/18_groupnorm/groupnorm_swish_fwd.cpp b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp index d10d16bf9d..26110193d7 100644 --- a/client_example/18_groupnorm/groupnorm_swish_fwd.cpp +++ b/client_example/18_groupnorm/groupnorm_swish_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/19_pool/avg_pool3d_bwd.cpp b/client_example/19_pool/avg_pool3d_bwd.cpp index 686d1da3ad..0bf4b9346e 100644 --- a/client_example/19_pool/avg_pool3d_bwd.cpp +++ b/client_example/19_pool/avg_pool3d_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/19_pool/avg_pool3d_fwd.cpp b/client_example/19_pool/avg_pool3d_fwd.cpp index 6739a41b2f..846bd5ff4d 100644 --- a/client_example/19_pool/avg_pool3d_fwd.cpp +++ b/client_example/19_pool/avg_pool3d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/19_pool/max_pool2d_bwd.cpp b/client_example/19_pool/max_pool2d_bwd.cpp index 53ece7425f..a90889656d 100644 --- a/client_example/19_pool/max_pool2d_bwd.cpp +++ b/client_example/19_pool/max_pool2d_bwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/19_pool/max_pool2d_fwd.cpp b/client_example/19_pool/max_pool2d_fwd.cpp index 84b818a60f..99087b47d3 100644 --- a/client_example/19_pool/max_pool2d_fwd.cpp +++ b/client_example/19_pool/max_pool2d_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp index a740c22f91..5ace2e3056 100644 --- a/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp +++ b/client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -88,7 +88,7 @@ int main(int argc, char* argv[]) [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp index c758720e10..fa08f49e7d 100644 --- a/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp +++ b/client_example/21_grouped_gemm_bias/grouped_gemm_fixed_nk_bias_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -79,7 +79,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp index 9e8eac536b..92311b484a 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_bf16.cpp @@ -77,7 +77,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp index b16fe90387..9dc5564fca 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -76,7 +76,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp index 045fe47c4f..3519e48aa6 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_fp8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,7 +77,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp index 8f82140f3f..d77f411a32 100644 --- a/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp +++ b/client_example/22_grouped_gemm/grouped_gemm_fixed_nk_i8.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -77,7 +77,7 @@ int main() [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { using Layout = decltype(layout); - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { return (nRow - 1) * stride + nCol; } diff --git a/client_example/22_im2col_col2im/image_to_column.cpp b/client_example/22_im2col_col2im/image_to_column.cpp index 8eafbdc5bb..0ceedd7862 100644 --- a/client_example/22_im2col_col2im/image_to_column.cpp +++ b/client_example/22_im2col_col2im/image_to_column.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include diff --git a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp index 65ba46fcd2..82d7de2a7d 100644 --- a/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp +++ b/client_example/23_elementwise_transpose/elementwise_transpose_3d.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include From acfb33923813e4374060c16f5490558efdadfa58 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 1 Mar 2024 12:30:38 -0600 Subject: [PATCH 15/36] Update clipping for fp8/bf8 conversion (#1182) * Update clipping for fp8 conversion * Add clipping for bf8 conversion * Format --- include/ck/utility/type_convert.hpp | 50 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index b989094c0e..dbac1f0c85 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -109,9 +109,6 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) { constexpr int seed = 1254739; uint32_t rng = prand_generator(reinterpret_cast(&x), x); - float max_fp8 = 240.0f; - if(!std::isinf(x)) - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) union { @@ -119,10 +116,15 @@ inline __host__ __device__ f8_t f8_convert_sr(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; + val.fval = x; + uint32_t ival = 0; + const float max_fp8 = 240.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); + ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; @@ -166,10 +168,15 @@ inline __host__ __device__ bf8_t f8_convert_sr(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; - ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos - val.i32val = ival; + val.fval = x; + uint32_t ival = 0; + const float max_bf8 = 57344.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); + ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; return val.i8val[0]; // little endian #else constexpr bool negative_zero_nan = true; @@ -208,9 +215,6 @@ __host__ __device__ constexpr Y f8_convert_rne(X x); template <> inline __host__ __device__ f8_t f8_convert_rne(float x) { - float max_fp8 = 240.0f; - if(!std::isinf(x)) - x = x > max_fp8 ? max_fp8 : (x < -max_fp8 ? -max_fp8 : x); #if defined(__gfx94__) union { @@ -218,8 +222,13 @@ inline __host__ __device__ f8_t f8_convert_rne(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; + val.fval = x; + uint32_t ival = 0; + const float max_fp8 = 240.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; return val.i8val[0]; @@ -263,8 +272,13 @@ inline __host__ __device__ bf8_t f8_convert_rne(float x) uint32_t i32val; uint8_t i8val[4]; // not endian independent } val; - val.fval = x; - uint32_t ival = 0; + val.fval = x; + uint32_t ival = 0; + const float max_bf8 = 57344.0f; + // if x is not +/- infinity or nan + if((val.i32val & NumericUtils::nan_mask) != NumericUtils::Inf) + // clip float value + val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; return val.i8val[0]; From 9ce18b045d6ffa6cfa29134229b422c07984ffb7 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 1 Mar 2024 18:42:15 -0600 Subject: [PATCH 16/36] Fix example_gemm_xdl_fp8 (#1183) --- example/01_gemm/gemm_xdl_fp8.cpp | 14 +++++++++----- example/01_gemm/gemm_xdl_fp8_bf8.cpp | 8 ++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp8.cpp b/example/01_gemm/gemm_xdl_fp8.cpp index 2d4df3fc13..7d8538681b 100644 --- a/example/01_gemm/gemm_xdl_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp8.cpp @@ -20,14 +20,18 @@ using BElementOp = PassThrough; using CElementOp = PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto LoopSched = ck::make_default_loop_scheduler(); +static constexpr auto PipelineVer = ck::PipelineVersion::v1; +using ComputeTypeA = ck::f8_t; +using ComputeTypeB = ck::f8_t; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp8_bf8.cpp b/example/01_gemm/gemm_xdl_fp8_bf8.cpp index b54df8ff3d..acc5fbc515 100644 --- a/example/01_gemm/gemm_xdl_fp8_bf8.cpp +++ b/example/01_gemm/gemm_xdl_fp8_bf8.cpp @@ -27,10 +27,10 @@ using ComputeTypeB = ck::bf8_t; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| Compute| Compute| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| TypeA| TypeB| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, LoopSched, PipelineVer, ComputeTypeA, ComputeTypeB>; // clang-format on From cf866211702a2124502df26fcaf4931f7218f2d8 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:42:16 -0800 Subject: [PATCH 17/36] [CI] Add CI build and test stage on MI300. (#1185) --- Jenkinsfile | 92 +++++++++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 41 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index becdc35b16..3cac20fd34 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1,5 +1,5 @@ def rocmnode(name) { - return '(rocmtest || miopen) && ' + name + return '(rocmtest || miopen) && (' + name + ')' } def show_node_info() { @@ -7,6 +7,7 @@ def show_node_info() { echo "NODE_NAME = \$NODE_NAME" lsb_release -sd uname -r + cat /sys/module/amdgpu/version ls /opt/ -la """ } @@ -33,6 +34,10 @@ def runShell(String command){ def getDockerImageName(){ def img + if (params.USE_CUSTOM_DOCKER != ""){ + img = "${params.USE_CUSTOM_DOCKER}" + } + else{ if (params.ROCMVERSION != "6.0.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" @@ -61,6 +66,7 @@ def getDockerImageName(){ } } } + } return img } @@ -365,8 +371,8 @@ def runCKProfiler(Map conf=[:]){ (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + sh 'rocminfo | tee rocminfo.log' + if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") } else{ @@ -379,20 +385,6 @@ def runCKProfiler(Map conf=[:]){ echo "The job was cancelled or aborted" throw e } - catch(Exception ex) { - retimage = docker.build("${image}", dockerArgs + " --no-cache .") - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ - throw new Exception ("GPU not found") - } - else{ - echo "GPU is OK" - } - } - } - } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') @@ -473,6 +465,7 @@ def Build_CK(Map conf=[:]){ show_node_info() env.HSA_ENABLE_SDMA=0 + env.DOCKER_BUILDKIT=1 checkout scm def image = getDockerImageName() @@ -487,25 +480,35 @@ def Build_CK(Map conf=[:]){ if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" def variant = env.STAGE_NAME def retimage def navi_node = 0 + def mi300_node = 0 - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) withDockerContainer(image: image, args: dockerOpts) { timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo | tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ + sh 'rocminfo | tee rocminfo.log' + if ( !runShell('grep -n "gfx" rocminfo.log') ){ throw new Exception ("GPU not found") } else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ navi_node = 1 + echo "This is a Navi node" + } + if ( runShell('grep -n "gfx942" rocminfo.log') ){ + mi300_node = 1 + echo "This is MI300 node" } } } @@ -514,23 +517,6 @@ def Build_CK(Map conf=[:]){ echo "The job was cancelled or aborted" throw e } - catch(Exception ex) { - retimage = docker.build("${image}", dockerArgs + " --no-cache .") - withDockerContainer(image: image, args: dockerOpts) { - timeout(time: 5, unit: 'MINUTES'){ - sh 'PATH="/opt/rocm/opencl/bin:/opt/rocm/opencl/bin/x86_64:$PATH" clinfo |tee clinfo.log' - if ( runShell('grep -n "Number of devices:.*. 0" clinfo.log') ){ - throw new Exception ("GPU not found") - } - else{ - echo "GPU is OK" - } - if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){ - navi_node = 1 - } - } - } - } withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { @@ -544,8 +530,8 @@ def Build_CK(Map conf=[:]){ sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' stash "ckProfiler.tar.gz" } - if (params.RUN_FULL_QA){ - // build deb packages + if (params.RUN_FULL_QA && mi300_node == 0 ){ + // build deb packages for all MI100/200/300 targets and prepare to export sh 'make -j package' archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb' @@ -610,7 +596,7 @@ def process_results(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) } @@ -678,6 +664,10 @@ pipeline { name: "BUILD_DOCKER", defaultValue: false, description: "Force building docker image (default: false), set to true if docker image needs to be updated.") + string( + name: 'USE_CUSTOM_DOCKER', + defaultValue: '', + description: 'If you want to use a custom docker image, please scecify it here (default: OFF).') string( name: 'ROCMVERSION', defaultValue: '6.0', @@ -828,6 +818,26 @@ pipeline { cleanWs() } } + stage("Build CK and run Tests on MI300") + { + when { + beforeAgent true + expression { params.RUN_FULL_QA.toBoolean() } + } + agent{ label rocmnode("gfx942") } + environment{ + setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx942" -DCMAKE_CXX_FLAGS=" -O3 " """ + execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ + cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ + -DGPU_TARGETS="gfx942" \ + -DCMAKE_CXX_COMPILER="${build_compiler()}" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ + } + steps{ + Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + cleanWs() + } + } stage("Build CK and run Tests on MI100/MI200") { when { From 8eff4d62b669df7c34e1490d38520537c0178e2f Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Tue, 5 Mar 2024 19:08:43 -0600 Subject: [PATCH 18/36] Add host lib (#1134) * Format * Format * Format * Remove const * Use the right template * Format * Format * add row/col instances * Add missing file * fixed * Format * Updates * Format * fixed rrr layout * Format * Update test and embed modules * Restore older version * Update year * Set -fPIC * Format * Use double for isnan * rename host folder to codegen + minor fix * add codegen CI test * add option to build components without building CK * fix the groovy syntax * fix typo * use the correct function for the codegen stage --------- Co-authored-by: Jing Zhang Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin --- Jenkinsfile | 51 +- cmake/Embed.cmake | 238 +++++ codegen/CMakeLists.txt | 49 + codegen/driver/main.cpp | 71 ++ .../ck/host/device_gemm_multiple_d.hpp | 42 + .../host/device_gemm_multiple_d/operation.hpp | 42 + .../host/device_gemm_multiple_d/problem.hpp | 39 + codegen/include/ck/host/headers.hpp | 18 + codegen/include/ck/host/operation/gemm.hpp | 49 + codegen/include/ck/host/stringutils.hpp | 104 +++ codegen/include/ck/host/types.hpp | 78 ++ codegen/include/ck/host/utils.hpp | 17 + codegen/src/device_gemm_multiple_d.cpp | 33 + ...gemm_multiple_d_operation_xdl_cshuffle.cpp | 295 ++++++ codegen/src/headers.cpp | 17 + codegen/src/types.cpp | 63 ++ codegen/src/utils.cpp | 21 + codegen/test/CMakeLists.txt | 11 + codegen/test/gemm_multiple_d.cpp | 185 ++++ codegen/test/include/test.hpp | 848 ++++++++++++++++++ codegen/test/rtc/CMakeLists.txt | 6 + .../test/rtc/include/rtc/compile_kernel.hpp | 27 + codegen/test/rtc/include/rtc/hip.hpp | 78 ++ codegen/test/rtc/include/rtc/kernel.hpp | 62 ++ codegen/test/rtc/include/rtc/manage_ptr.hpp | 55 ++ codegen/test/rtc/include/rtc/tmp_dir.hpp | 24 + codegen/test/rtc/src/compile_kernel.cpp | 95 ++ codegen/test/rtc/src/hip.cpp | 102 +++ codegen/test/rtc/src/kernel.cpp | 121 +++ codegen/test/rtc/src/tmp_dir.cpp | 48 + .../device_gemm_multiple_d_xdl_cshuffle.hpp | 335 +++++-- .../gpu/grid/block_to_ctile_map.hpp | 53 +- .../gridwise_gemm_multiple_d_xdl_cshuffle.hpp | 10 +- 33 files changed, 3170 insertions(+), 117 deletions(-) create mode 100644 cmake/Embed.cmake create mode 100644 codegen/CMakeLists.txt create mode 100644 codegen/driver/main.cpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d.hpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d/operation.hpp create mode 100644 codegen/include/ck/host/device_gemm_multiple_d/problem.hpp create mode 100644 codegen/include/ck/host/headers.hpp create mode 100644 codegen/include/ck/host/operation/gemm.hpp create mode 100644 codegen/include/ck/host/stringutils.hpp create mode 100644 codegen/include/ck/host/types.hpp create mode 100644 codegen/include/ck/host/utils.hpp create mode 100644 codegen/src/device_gemm_multiple_d.cpp create mode 100644 codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp create mode 100644 codegen/src/headers.cpp create mode 100644 codegen/src/types.cpp create mode 100644 codegen/src/utils.cpp create mode 100644 codegen/test/CMakeLists.txt create mode 100644 codegen/test/gemm_multiple_d.cpp create mode 100644 codegen/test/include/test.hpp create mode 100644 codegen/test/rtc/CMakeLists.txt create mode 100644 codegen/test/rtc/include/rtc/compile_kernel.hpp create mode 100644 codegen/test/rtc/include/rtc/hip.hpp create mode 100644 codegen/test/rtc/include/rtc/kernel.hpp create mode 100644 codegen/test/rtc/include/rtc/manage_ptr.hpp create mode 100644 codegen/test/rtc/include/rtc/tmp_dir.hpp create mode 100644 codegen/test/rtc/src/compile_kernel.cpp create mode 100644 codegen/test/rtc/src/hip.cpp create mode 100644 codegen/test/rtc/src/kernel.cpp create mode 100644 codegen/test/rtc/src/tmp_dir.cpp diff --git a/Jenkinsfile b/Jenkinsfile index 3cac20fd34..abecb76408 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -264,18 +264,24 @@ def cmake_build(Map conf=[:]){ """) sh cmd3 } - - def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") // reduce parallelism when compiling, clang uses too much memory def nt = nthreads() - def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") + def cmd def execute_cmd = conf.get("execute_cmd", "") - - def cmd = conf.get("cmd", """ + if(!setup_args.contains("NO_CK_BUILD")){ + def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ") + def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}") + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} ${execute_cmd} """) + } + else{ + cmd = conf.get("cmd", """ + ${execute_cmd} + """) + } echo cmd @@ -667,7 +673,7 @@ pipeline { string( name: 'USE_CUSTOM_DOCKER', defaultValue: '', - description: 'If you want to use a custom docker image, please scecify it here (default: OFF).') + description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( name: 'ROCMVERSION', defaultValue: '6.0', @@ -712,6 +718,10 @@ pipeline { name: "RUN_PERFORMANCE_TESTS", defaultValue: false, description: "Run the performance tests (default: OFF)") + booleanParam( + name: "RUN_CODEGEN_TESTS", + defaultValue: true, + description: "Run the codegen tests (default: ON)") } environment{ dbuser = "${dbuser}" @@ -790,7 +800,34 @@ pipeline { } } } - + stage("Run Codegen Tests") + { + parallel + { + stage("Run Codegen Tests on MI100/MI200") + { + when { + beforeAgent true + expression { params.RUN_CODEGEN_TESTS.toBoolean() } + } + options { retry(2) } + agent{ label rocmnode("gfx908 || gfx90a")} + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cd ../codegen && rm -rf build && mkdir build && cd build && \ + cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx908;gfx90a" \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j check""" + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } + } + } stage("Build CK and run Tests") { parallel diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake new file mode 100644 index 0000000000..4bc638b446 --- /dev/null +++ b/cmake/Embed.cmake @@ -0,0 +1,238 @@ +##################################################################################### +# The MIT License (MIT) +# +# Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +##################################################################################### + +if(WIN32) + set(EMBED_USE RC CACHE STRING "Use RC or CArrays to embed data files") + set_property(CACHE EMBED_USE PROPERTY STRINGS "RC;CArrays") +else() + if(BUILD_SHARED_LIBS) + set(EMBED_USE LD CACHE STRING "Use LD or CArrays to embed data files") + else() + set(EMBED_USE CArrays CACHE STRING "Use LD or CArrays to embed data files") + endif() + set_property(CACHE EMBED_USE PROPERTY STRINGS "LD;CArrays") +endif() + +if(EMBED_USE STREQUAL "LD") + find_program(EMBED_LD ld REQUIRED) + find_program(EMBED_OBJCOPY objcopy REQUIRED) +endif() + +function(embed_wrap_string) + set(options) + set(oneValueArgs VARIABLE AT_COLUMN) + set(multiValueArgs) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + string(LENGTH ${${PARSE_VARIABLE}} string_length) + math(EXPR offset "0") + + while(string_length GREATER 0) + + if(string_length GREATER ${PARSE_AT_COLUMN}) + math(EXPR length "${PARSE_AT_COLUMN}") + else() + math(EXPR length "${string_length}") + endif() + + string(SUBSTRING ${${PARSE_VARIABLE}} ${offset} ${length} line) + set(lines "${lines}\n${line}") + + math(EXPR string_length "${string_length} - ${length}") + math(EXPR offset "${offset} + ${length}") + endwhile() + + set(${PARSE_VARIABLE} "${lines}" PARENT_SCOPE) +endfunction() + +function(generate_embed_source EMBED_NAME EMBED_DIR BASE_DIRECTORY) + set(options) + set(oneValueArgs) + set(multiValueArgs SYMBOLS FILES) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(RESOURCE_ID 100) + + list(LENGTH PARSE_SYMBOLS SYMBOLS_LEN) + list(LENGTH PARSE_FILES FILES_LEN) + if(NOT ${SYMBOLS_LEN} EQUAL ${FILES_LEN}) + message(FATAL_ERROR "Symbols and objects dont match: ${SYMBOLS_LEN} != ${FILES_LEN}") + endif() + math(EXPR LEN "${SYMBOLS_LEN} - 1") + + foreach(idx RANGE ${LEN}) + list(GET PARSE_SYMBOLS ${idx} SYMBOL) + list(GET PARSE_FILES ${idx} FILE) + file(RELATIVE_PATH BASE_NAME "${BASE_DIRECTORY}" ${FILE}) + if(EMBED_USE STREQUAL "RC") + string(TOUPPER "${SYMBOL}" SYMBOL) + string(APPEND FILE_IDS "#define IDR_${SYMBOL} ${RESOURCE_ID}\n") + file(TO_NATIVE_PATH "${FILE}" NATIVE_FILE) + string(REPLACE "\\" "\\\\" NATIVE_FILE "${NATIVE_FILE}") + string(APPEND RC_FILE_MAPPING "IDR_${SYMBOL} TEXTFILE \"${NATIVE_FILE}\"\n") + string(APPEND INIT_KERNELS "\n {\"${BASE_NAME}\", resource::read(IDR_${SYMBOL})},") + math(EXPR RESOURCE_ID "${RESOURCE_ID} + 1" OUTPUT_FORMAT DECIMAL) + else() + set(START_SYMBOL "_binary_${SYMBOL}_start") + set(LENGTH_SYMBOL "_binary_${SYMBOL}_length") + if(EMBED_USE STREQUAL "LD") + string(APPEND EXTERNS " +extern const char ${START_SYMBOL}[]; +extern const size_t _binary_${SYMBOL}_size; +const auto ${LENGTH_SYMBOL} = reinterpret_cast(&_binary_${SYMBOL}_size); +") + else() + string(APPEND EXTERNS " +extern const char ${START_SYMBOL}[]; +extern const size_t ${LENGTH_SYMBOL}; +") + endif() + string(APPEND INIT_KERNELS " + { \"${BASE_NAME}\", { ${START_SYMBOL}, ${LENGTH_SYMBOL}} },") + endif() + endforeach() + if(EMBED_USE STREQUAL "RC") + file(WRITE "${EMBED_DIR}/include/resource.h" " +#define TEXTFILE 256 + +${FILE_IDS} +") + file(WRITE "${EMBED_DIR}/resource.rc" " +#include \"resource.h\" + +${RC_FILE_MAPPING} +") + set(EXTERNS " +#include +#include \"resource.h\" + +namespace resource { +std::string_view read(int id) +{ + HMODULE handle = GetModuleHandle(nullptr); + HRSRC rc = FindResource(handle, MAKEINTRESOURCE(id), MAKEINTRESOURCE(TEXTFILE)); + HGLOBAL data = LoadResource(handle, rc); + return {static_cast(LockResource(data)), SizeofResource(handle, rc)}; +} +} +") + set(EMBED_FILES ${EMBED_DIR}/include/resource.h ${EMBED_DIR}/resource.rc) + endif() + file(WRITE "${EMBED_DIR}/include/${EMBED_NAME}.hpp" " +#include +#include +#include +std::unordered_map ${EMBED_NAME}(); +") + + file(WRITE "${EMBED_DIR}/${EMBED_NAME}.cpp" " +#include <${EMBED_NAME}.hpp> +${EXTERNS} +std::unordered_map ${EMBED_NAME}() +{ + static std::unordered_map result = {${INIT_KERNELS} + }; + return result; +} +") + list(APPEND EMBED_FILES ${EMBED_DIR}/${EMBED_NAME}.cpp ${EMBED_DIR}/include/${EMBED_NAME}.hpp) + set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE) +endfunction() + +function(embed_file FILE BASE_DIRECTORY) + message(STATUS " ${FILE}") + file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" ${FILE}) + string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL) + get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") + if(EMBED_USE STREQUAL "LD") + set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") + add_custom_command( + OUTPUT "${OUTPUT_FILE}" + COMMAND ${EMBED_LD} -r -o "${OUTPUT_FILE}" -z noexecstack --format=binary "${REL_FILE}" + COMMAND ${EMBED_OBJCOPY} --rename-section .data=.rodata,alloc,load,readonly,data,contents "${OUTPUT_FILE}" + WORKING_DIRECTORY "${BASE_DIRECTORY}" + DEPENDS "${FILE}" + VERBATIM) + set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE) + elseif(EMBED_USE STREQUAL "CArrays") + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS ${FILE}) + set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.cpp") + # reads source file contents as hex string + file(READ ${FILE} HEX_STRING HEX) + # wraps the hex string into multiple lines + embed_wrap_string(VARIABLE HEX_STRING AT_COLUMN 80) + # adds '0x' prefix and comma suffix before and after every byte respectively + string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1, " ARRAY_VALUES ${HEX_STRING}) + # removes trailing comma + string(REGEX REPLACE ", $" "" ARRAY_VALUES ${ARRAY_VALUES}) + file(WRITE "${OUTPUT_FILE}" " +#include +extern const char _binary_${OUTPUT_SYMBOL}_start[] = { ${ARRAY_VALUES} }; +extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SYMBOL}_start); +") + set(OUTPUT_FILE ${OUTPUT_FILE} PARENT_SCOPE) + endif() + set(OUTPUT_SYMBOL ${OUTPUT_SYMBOL} PARENT_SCOPE) +endfunction() + +function(add_embed_library EMBED_NAME) + set(options) + set(oneValueArgs RELATIVE) + set(multiValueArgs) + cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(EMBED_DIR ${CMAKE_CURRENT_BINARY_DIR}/embed/${EMBED_NAME}) + file(MAKE_DIRECTORY ${EMBED_DIR}) + message(STATUS "Embedding kernel files:") + foreach(FILE ${PARSE_UNPARSED_ARGUMENTS}) + embed_file(${FILE} ${PARSE_RELATIVE}) + list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) + list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) + endforeach() + message(STATUS "Generating embedding library '${EMBED_NAME}'") + generate_embed_source(${EMBED_NAME} ${EMBED_DIR} "${PARSE_RELATIVE}" SYMBOLS ${SYMBOLS} FILES ${PARSE_UNPARSED_ARGUMENTS}) + set(INTERNAL_EMBED_LIB embed_lib_${EMBED_NAME}) + if(EMBED_USE STREQUAL "LD") + add_library(${INTERNAL_EMBED_LIB} STATIC ${EMBED_FILES} ${OUTPUT_FILES}) + else() + add_library(${INTERNAL_EMBED_LIB} OBJECT ${EMBED_FILES}) + endif() + if(EMBED_USE STREQUAL "CArrays") + target_sources(${INTERNAL_EMBED_LIB} PRIVATE ${OUTPUT_FILES}) + endif() + target_include_directories(${INTERNAL_EMBED_LIB} PRIVATE "${EMBED_DIR}/include") + target_compile_options(${INTERNAL_EMBED_LIB} PRIVATE -Wno-reserved-identifier -Wno-extern-initializer -Wno-missing-variable-declarations) + set_target_properties(${INTERNAL_EMBED_LIB} PROPERTIES POSITION_INDEPENDENT_CODE On) + add_library(${EMBED_NAME} INTERFACE) + if(EMBED_USE STREQUAL "RC") + target_link_libraries(${EMBED_NAME} INTERFACE $) + elseif(EMBED_USE STREQUAL "LD") + target_link_libraries(${EMBED_NAME} INTERFACE ${INTERNAL_EMBED_LIB}) + else() + target_sources(${EMBED_NAME} INTERFACE $) + endif() + target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include") +endfunction() + diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt new file mode 100644 index 0000000000..72549c9a4e --- /dev/null +++ b/codegen/CMakeLists.txt @@ -0,0 +1,49 @@ +cmake_minimum_required(VERSION 3.16) +project(composable_kernel_host) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) + +find_package(ROCM) +include(ROCMInstallTargets) +include(ROCMTest) + +list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake) +include(Embed) +file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck/*.hpp) +message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") +message(STATUS "RELATIVE: ${CK_ROOT}/include") +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) + +add_definitions(-std=c++17) + +file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) +# TODO: Use object library +add_library(ck_host STATIC ${SOURCES}) +target_link_libraries(ck_host PRIVATE ck_headers) + +set_target_properties(ck_host PROPERTIES + LINKER_LANGUAGE CXX + POSITION_INDEPENDENT_CODE ON) + +target_include_directories(ck_host PUBLIC + $ +) + +add_executable(ck-template-driver driver/main.cpp) +target_link_libraries(ck-template-driver ck_host) + +rocm_install( + TARGETS ck_host ck_headers + EXPORT ck_hostTargets +) +rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}) + +if(BUILD_TESTING) +add_subdirectory(test) +endif() diff --git a/codegen/driver/main.cpp b/codegen/driver/main.cpp new file mode 100644 index 0000000000..dfd513106b --- /dev/null +++ b/codegen/driver/main.cpp @@ -0,0 +1,71 @@ + +#include +#include +#include +#include +#include +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/stringutils.hpp" + +using ck::host::Transform; + +struct Emitters +{ + std::unordered_map()>> m; + + template + void Register(const std::string& name) + { + m[name] = [] { + auto configs = T::CreateOperations(); + + return Transform(configs, [](const auto& ops) { return ToTuple(ops); }); + }; + } + + template + static std::string ToTuple(const T& ops) + { + auto templates = Transform( + ops, [](const auto& op) { return " " + op.ToSolution().ToTemplateString(); }); + return "std::tuple<\n" + ck::host::JoinStrings(templates, ",\n") + ">"; + } + + std::string Emit(const std::string& name) { return ck::host::JoinStrings(m.at(name)(), "\n"); } + + std::vector List() const + { + return Transform(m, [](auto&& p) { return p.first; }); + } +}; + +int main(int argc, const char* argv[]) +{ + std::string prog = argv[0]; + std::vector args(argv + 1, argv + argc); + Emitters e; + e.Register( + "DeviceGemmMultipleD_Xdl_CShuffle"); + + if(args.empty() or std::any_of(args.begin(), args.end(), [](auto arg) { + return arg == "-h" or arg == "--help"; + })) + { + std::cout << "USAGE:" << std::endl; + std::cout << " " << prog << " [TEMPLATE]" << std::endl; + std::cout << std::endl; + std::cout << "FLAGS:" << std::endl; + std::cout << " -h, --help Show help" << std::endl; + std::cout << std::endl; + std::cout << "TEMPLATES:" << std::endl; + for(auto x : e.List()) + std::cout << " " << x << std::endl; + std::cout << std::endl; + return 0; + } + + for(auto name : args) + std::cout << e.Emit(name) << std::endl; + + return 0; +} diff --git a/codegen/include/ck/host/device_gemm_multiple_d.hpp b/codegen/include/ck/host/device_gemm_multiple_d.hpp new file mode 100644 index 0000000000..88e040db53 --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string BElementOp = "ck::tensor_operation::element_wise::PassThrough"; + std::string CDEElementOp = "ck::Tuple<>"; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp new file mode 100644 index 0000000000..f9d39633ac --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/operation.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/operation/gemm.hpp" +#include "ck/host/device_gemm_multiple_d/problem.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Operation_Xdl_CShuffle +{ + static std::vector> CreateOperations(); + static std::vector CreateOperations(const Problem& prob); + TensorDesc A{}; + TensorDesc B{}; + DataType acc = DataType::Float; + DataType cs_type = DataType::Half; + std::vector Ds = {}; + TensorDesc E{}; + std::string a_elem_op = PassThrough; + std::string b_elem_op = PassThrough; + std::string cde_elem_op = Bilinear; + std::string gemm_specialization = "ck::tensor_operation::device::GemmSpecialization::Default"; + operation::TileDesc tile_desc{}; + operation::BlockTransferDesc a_block_transfer{}; + operation::BlockTransferDesc b_block_transfer{}; + operation::CShuffleDesc cshuffle{}; + operation::CBlockTransferDesc c_block_transfer{}; + + Solution ToSolution() const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp new file mode 100644 index 0000000000..f6dbc2b6e8 --- /dev/null +++ b/codegen/include/ck/host/device_gemm_multiple_d/problem.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +struct Problem +{ + std::size_t M = 0; + std::size_t N = 0; + std::size_t K = 0; + bool TransA = false; + bool TransB = false; + bool TransE = false; + std::vector DsTrans = {}; + DataType ADataType = DataType::Half; + DataType BDataType = DataType::Half; + DataType EDataType = DataType::Half; + std::vector DsDataType = {}; + std::string AElementOp = PassThrough; + std::string BElementOp = PassThrough; + std::string CDEElementOp = PassThrough; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp new file mode 100644 index 0000000000..3da05baaaf --- /dev/null +++ b/codegen/include/ck/host/headers.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +std::unordered_map GetHeaders(); + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/operation/gemm.hpp b/codegen/include/ck/host/operation/gemm.hpp new file mode 100644 index 0000000000..f587122b05 --- /dev/null +++ b/codegen/include/ck/host/operation/gemm.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck { +namespace host { +namespace operation { + +struct TileDesc +{ + int block_size = 0; + int m_per_block = 0; + int n_per_block = 0; + int k_per_block = 0; + int ak1 = 0; + int bk1 = 0; + int m_per_XDL = 0; + int n_per_XDL = 0; + int m_Xdl_per_wave = 0; + int n_Xdl_per_wave = 0; + int num_gemmk_prefetch_stage = 0; +}; +struct BlockTransferDesc +{ + std::string thread_cluster_length = ""; + std::string thread_cluster_arrange_order = ""; + std::string src_access_order = ""; + int src_vec_dim = 0; + int src_scalar_per_vector = 0; + int dst_scalar_per_vector_k1 = 0; + int lds_add_extra_dim = 0; +}; +struct CShuffleDesc +{ + int m_Xdl_per_wave_per_shuffle = 0; + int n_Xdl_per_wave_per_shuffle = 0; +}; +struct CBlockTransferDesc +{ + std::string cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl = ""; + int scalar_per_vector_n_wave_n_per_Xdl = 0; +}; + +} // namespace operation +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp new file mode 100644 index 0000000000..01374b86c8 --- /dev/null +++ b/codegen/include/ck/host/stringutils.hpp @@ -0,0 +1,104 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +template +std::string trim(const std::string& s, F f) +{ + auto start = std::find_if_not(s.begin(), s.end(), f); + auto last = std::find_if_not(s.rbegin(), std::string::const_reverse_iterator(start), f).base(); + return {start, last}; +} + +inline std::string trim(const std::string& s) +{ + return trim(s, [](unsigned char c) { return std::isspace(c); }); +} + +template +inline std::string JoinStrings(Strings strings, const std::string& delim) +{ + auto it = strings.begin(); + if(it == strings.end()) + return ""; + + auto nit = std::next(it); + return std::accumulate(nit, strings.end(), *it, [&](std::string x, std::string y) { + return std::move(x) + delim + std::move(y); + }); +} + +template +inline std::string +InterpolateString(const std::string& input, F f, std::string start = "${", std::string end = "}") +{ + std::string result = ""; + result.reserve(input.size()); + auto it = input.begin(); + while(it != input.end()) + { + auto next_start = std::search(it, input.end(), start.begin(), start.end()); + auto next_end = std::search(next_start, input.end(), end.begin(), end.end()); + result.append(it, next_start); + if(next_start == input.end()) + break; + if(next_end == input.end()) + { + throw std::runtime_error("Unbalanced brackets"); + } + auto r = f(next_start + start.size(), next_end); + result.append(r.begin(), r.end()); + it = next_end + end.size(); + } + return result; +} +inline std::string InterpolateString(const std::string& input, + const std::unordered_map& vars, + std::string start = "${", + std::string end = "}") +{ + return InterpolateString( + input, + [&](auto start_it, auto last_it) { + auto key = trim({start_it, last_it}); + auto it = vars.find(key); + if(it == vars.end()) + throw std::runtime_error("Unknown key: " + key); + return it->second; + }, + std::move(start), + std::move(end)); +} + +template +inline auto Transform(const Range& r, F f) -> std::vector +{ + std::vector result; + std::transform(r.begin(), r.end(), std::back_inserter(result), f); + return result; +} + +template +inline auto Transform(const Range1& r1, const Range2& r2, F f) + -> std::vector +{ + std::vector result; + assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); + std::transform(r1.begin(), r1.end(), r2.begin(), std::back_inserter(result), f); + return result; +} + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/types.hpp b/codegen/include/ck/host/types.hpp new file mode 100644 index 0000000000..23488a66d0 --- /dev/null +++ b/codegen/include/ck/host/types.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +namespace ck { +namespace host { + +struct Solution +{ + + Solution() = default; + Solution(std::string str, std::unordered_map values); + std::string ToTemplateString() const; + std::string GetTemplateParameter(const std::string& name) const; + template + T GetTemplateParameter(const std::string& name) const + { + T result; + std::stringstream ss(GetTemplateParameter(name)); + ss >> result; + return result; + } + + private: + std::string template_str; + std::unordered_map template_values; +}; + +enum class DataType +{ + Half, + Float, + Int8, + Int32 +}; + +std::string ToString(DataType dt); + +enum class Layout +{ + Row, + Column +}; + +std::string ToString(Layout dl); + +enum class GemmType +{ + Default +}; + +std::string ToString(GemmType gt); + +struct TensorDesc +{ + DataType element; + Layout layout; +}; + +std::string SequenceStr(const std::vector& v); + +std::string MakeTuple(const std::vector& v); + +template +const std::string S = SequenceStr({xs...}); + +constexpr const char* PassThrough = "ck::tensor_operation::element_wise::PassThrough"; +constexpr const char* Bilinear = "ck::tensor_operation::element_wise::Bilinear"; + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/utils.hpp b/codegen/include/ck/host/utils.hpp new file mode 100644 index 0000000000..e8785a456f --- /dev/null +++ b/codegen/include/ck/host/utils.hpp @@ -0,0 +1,17 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y); + +const std::unordered_set& get_xdlop_archs(); + +} // namespace host +} // namespace ck diff --git a/codegen/src/device_gemm_multiple_d.cpp b/codegen/src/device_gemm_multiple_d.cpp new file mode 100644 index 0000000000..ec25afc0f9 --- /dev/null +++ b/codegen/src/device_gemm_multiple_d.cpp @@ -0,0 +1,33 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +std::string Problem::GetIncludeHeader() const +{ + return "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp"; +} + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + if(get_xdlop_archs().count(arch) == 0) + return {}; + auto ops = ck::host::device_gemm_multiple_d::Operation_Xdl_CShuffle::CreateOperations(*this); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [&](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck \ No newline at end of file diff --git a/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp new file mode 100644 index 0000000000..9e397497ee --- /dev/null +++ b/codegen/src/device_gemm_multiple_d_operation_xdl_cshuffle.cpp @@ -0,0 +1,295 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include + +namespace ck { +namespace host { +namespace device_gemm_multiple_d { + +static std::string GetGemmSpec(const std::size_t m, + const std::size_t n, + const std::size_t k, + const std::size_t m_per_block, + const std::size_t n_per_block, + const std::size_t k_per_block) +{ + std::string spec = ""; + if(integer_divide_ceil(m, m_per_block) * m_per_block - m != 0) + spec += "M"; + if(integer_divide_ceil(n, n_per_block) * n_per_block - n != 0) + spec += "N"; + if(integer_divide_ceil(k, k_per_block) * k_per_block - k != 0) + spec += "K"; + if(spec == "") + return "ck::tensor_operation::device::GemmSpecialization::Default"; + + return "ck::tensor_operation::device::GemmSpecialization::" + spec + "Padding"; +} + +static Layout ToLayout(bool Trans) { return Trans ? Layout::Column : Layout::Row; } + +std::vector Operation_Xdl_CShuffle::CreateOperations(const Problem& prob) +{ + std::vector result; + + std::vector tile_descriptions = { + // clang-format off +// Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| NumGemmK| +// Size| Block| Block| Block| | | XDL| XDL| Per| Per| Prefetch| +// | | | | | | | | Wave| Wave| Stage| +// | | | | | | | | | | | + { 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, 1}, + { 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, 1}, + { 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, 1}, + { 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, 1}, + { 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, 1}, + { 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_rowmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector a_block_descriptions_colmajor = { + // clang-format off +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| +// Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + // clang-format on + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + {S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + {S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + }; + + std::vector b_block_descriptions_rowmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + { S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1}, + { S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1}, + // clang-format on + }; + + std::vector b_block_descriptions_colmajor = { + // clang-format off +// BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| +// ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| +// Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | +// | | | | | | | + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + { S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1}, + // clang-format on + }; + + std::vector cshuffle_descriptions = { + // clang-format off +// CShuffle| CShuffle| +// MXdlPerWave| NXdlPerWave| +// PerShuffle| PerShuffle| +// | | + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + { 1, 1}, + // clang-format on + }; + + std::vector c_block_descriptions = { + // clang-format off +// CBlockTransferClusterLengths| CBlockTransfer +// _MBlock_MWaveMPerXdl| ScalarPerVector +// _NBlock_NWaveNPerXdl| _NWaveNPerXdl +// | + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 4>, 8}, + { S<1, 16, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + { S<1, 32, 1, 8>, 8}, + // clang-format on + }; + + const auto a_block_descriptions = + prob.TransA ? a_block_descriptions_colmajor : a_block_descriptions_rowmajor; + const auto b_block_descriptions = + prob.TransB ? b_block_descriptions_colmajor : b_block_descriptions_rowmajor; + + assert(tile_descriptions.size() == a_block_descriptions.size()); + assert(tile_descriptions.size() == b_block_descriptions.size()); + assert(tile_descriptions.size() == cshuffle_descriptions.size()); + assert(tile_descriptions.size() == c_block_descriptions.size()); + + for(std::size_t i = 0; i < tile_descriptions.size(); i++) + { + Operation_Xdl_CShuffle x; + x.tile_desc = tile_descriptions[i]; + x.a_block_transfer = a_block_descriptions[i]; + x.b_block_transfer = b_block_descriptions[i]; + x.cshuffle = cshuffle_descriptions[i]; + x.c_block_transfer = c_block_descriptions[i]; + x.A = TensorDesc{prob.ADataType, ToLayout(prob.TransA)}; + x.B = TensorDesc{prob.BDataType, ToLayout(prob.TransB)}; + x.E = TensorDesc{prob.EDataType, ToLayout(prob.TransE)}; + x.Ds = Transform(prob.DsTrans, prob.DsDataType, [](auto trans, auto dt) { + return TensorDesc{dt, ToLayout(trans)}; + }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; + x.gemm_specialization = GetGemmSpec(prob.M, + prob.N, + prob.K, + x.tile_desc.m_per_block, + x.tile_desc.n_per_block, + x.tile_desc.k_per_block); + result.push_back(x); + } + return result; +} + +std::vector> Operation_Xdl_CShuffle::CreateOperations() +{ + std::vector problems; + for(bool TransA : {true, false}) + for(bool TransB : {true, false}) + { + Problem prob; + prob.TransA = TransA; + prob.TransB = TransB; + problems.push_back(prob); + } + return Transform(problems, [](const Problem& p) { return CreateOperations(p); }); +} + +static const char* const DeviceGemmMultipleD_Xdl_CShuffleTemplate = + "ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_CShuffle<${LayoutA}, ${LayoutB}, " + "${LayoutDs}, ${LayoutE}, ${ADataType}, ${BDataType}, ${AccDataType}, ${CShuffleDataType}, " + "${DsDataType}, ${EDataType}, ${AElementwiseOperation}, ${BElementwiseOperation}, " + "${CDEElementwiseOperation}, ${GemmSpecialization}, ${NumGemmkPrefetchStage}, ${BlockSize}, " + "${MPerBlock}, ${NPerBlock}, ${KPerBlock}, ${AK1}, ${BK1}, ${MPerXDL}, ${NPerXDL}, " + "${MXdlPerWave}, ${NXdlPerWave}, ${ABlockTransferThreadClusterLengths_AK0_M_AK1}, " + "${ABlockTransferThreadClusterArrangeOrder}, ${ABlockTransferSrcAccessOrder}, " + "${ABlockTransferSrcVectorDim}, ${ABlockTransferSrcScalarPerVector}, " + "${ABlockTransferDstScalarPerVector_AK1}, ${ABlockLdsExtraM}, " + "${BBlockTransferThreadClusterLengths_BK0_N_BK1}, ${BBlockTransferThreadClusterArrangeOrder}, " + "${BBlockTransferSrcAccessOrder}, ${BBlockTransferSrcVectorDim}, " + "${BBlockTransferSrcScalarPerVector}, ${BBlockTransferDstScalarPerVector_BK1}, " + "${BBlockLdsExtraN}, ${CShuffleMXdlPerWavePerShuffle}, ${CShuffleNXdlPerWavePerShuffle}, " + "${CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock}, " + "${CDEBlockTransferScalarPerVector_NPerBlock}>"; + +Solution Operation_Xdl_CShuffle::ToSolution() const +{ + std::unordered_map values = { + {"LayoutA", ToString(this->A.layout)}, + {"LayoutB", ToString(this->B.layout)}, + {"LayoutDs", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.layout); }))}, + {"LayoutE", ToString(this->E.layout)}, + {"ADataType", ToString(this->A.element)}, + {"BDataType", ToString(this->B.element)}, + {"AccDataType", ToString(this->acc)}, + {"CShuffleDataType", ToString(this->cs_type)}, + {"DsDataType", + MakeTuple(Transform(this->Ds, [](auto tensor) { return ToString(tensor.element); }))}, + {"EDataType", ToString(this->E.element)}, + {"AElementwiseOperation", this->a_elem_op}, + {"BElementwiseOperation", this->b_elem_op}, + {"CDEElementwiseOperation", this->cde_elem_op}, + {"GemmSpecialization", this->gemm_specialization}, + {"NumGemmkPrefetchStage", std::to_string(this->tile_desc.num_gemmk_prefetch_stage)}, + {"BlockSize", std::to_string(this->tile_desc.block_size)}, + {"MPerBlock", std::to_string(this->tile_desc.m_per_block)}, + {"NPerBlock", std::to_string(this->tile_desc.n_per_block)}, + {"KPerBlock", std::to_string(this->tile_desc.k_per_block)}, + {"AK1", std::to_string(this->tile_desc.ak1)}, + {"BK1", std::to_string(this->tile_desc.bk1)}, + {"MPerXDL", std::to_string(this->tile_desc.m_per_XDL)}, + {"NPerXDL", std::to_string(this->tile_desc.n_per_XDL)}, + {"MXdlPerWave", std::to_string(this->tile_desc.m_Xdl_per_wave)}, + {"NXdlPerWave", std::to_string(this->tile_desc.n_Xdl_per_wave)}, + {"ABlockTransferThreadClusterLengths_AK0_M_AK1", + this->a_block_transfer.thread_cluster_length}, + {"ABlockTransferThreadClusterArrangeOrder", + this->a_block_transfer.thread_cluster_arrange_order}, + {"ABlockTransferSrcAccessOrder", this->a_block_transfer.src_access_order}, + {"ABlockTransferSrcVectorDim", std::to_string(this->a_block_transfer.src_vec_dim)}, + {"ABlockTransferSrcScalarPerVector", + std::to_string(this->a_block_transfer.src_scalar_per_vector)}, + {"ABlockTransferDstScalarPerVector_AK1", + std::to_string(this->a_block_transfer.dst_scalar_per_vector_k1)}, + {"ABlockLdsExtraM", std::to_string(this->a_block_transfer.lds_add_extra_dim)}, + {"BBlockTransferThreadClusterLengths_BK0_N_BK1", + this->b_block_transfer.thread_cluster_length}, + {"BBlockTransferThreadClusterArrangeOrder", + this->b_block_transfer.thread_cluster_arrange_order}, + {"BBlockTransferSrcAccessOrder", this->b_block_transfer.src_access_order}, + {"BBlockTransferSrcVectorDim", std::to_string(this->b_block_transfer.src_vec_dim)}, + {"BBlockTransferSrcScalarPerVector", + std::to_string(this->b_block_transfer.src_scalar_per_vector)}, + {"BBlockTransferDstScalarPerVector_BK1", + std::to_string(this->b_block_transfer.dst_scalar_per_vector_k1)}, + {"BBlockLdsExtraN", std::to_string(this->b_block_transfer.lds_add_extra_dim)}, + {"CShuffleMXdlPerWavePerShuffle", + std::to_string(this->cshuffle.m_Xdl_per_wave_per_shuffle)}, + {"CShuffleNXdlPerWavePerShuffle", + std::to_string(this->cshuffle.n_Xdl_per_wave_per_shuffle)}, + {"CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock", + this->c_block_transfer.cluster_lengths_m_block_m_wave_m_per_Xdl_n_block_n_wave_n_per_Xdl}, + {"CDEBlockTransferScalarPerVector_NPerBlock", + std::to_string(this->c_block_transfer.scalar_per_vector_n_wave_n_per_Xdl)}, + }; + + return Solution{InterpolateString(DeviceGemmMultipleD_Xdl_CShuffleTemplate, values), + std::move(values)}; +} + +} // namespace device_gemm_multiple_d +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp new file mode 100644 index 0000000000..6fcb94cdbd --- /dev/null +++ b/codegen/src/headers.cpp @@ -0,0 +1,17 @@ +#include "ck/host/headers.hpp" +#include "ck_headers.hpp" + +namespace ck { +namespace host { + +const std::string config_header = ""; + +std::unordered_map GetHeaders() +{ + auto headers = ck_headers(); + headers.insert(std::make_pair("ck/config.h", config_header)); + return headers; +} + +} // namespace host +} // namespace ck \ No newline at end of file diff --git a/codegen/src/types.cpp b/codegen/src/types.cpp new file mode 100644 index 0000000000..d43df73f33 --- /dev/null +++ b/codegen/src/types.cpp @@ -0,0 +1,63 @@ +#include "ck/host/types.hpp" +#include "ck/host/stringutils.hpp" +#include +#include + +namespace ck { +namespace host { + +Solution::Solution(std::string str, std::unordered_map values) + : template_str(std::move(str)), template_values(std::move(values)) +{ +} + +std::string Solution::ToTemplateString() const { return this->template_str; } +std::string Solution::GetTemplateParameter(const std::string& name) const +{ + return this->template_values.at(name); +} + +std::string ToString(DataType dt) +{ + switch(dt) + { + case DataType::Float: return "float"; + case DataType::Half: return "ck::half_t"; + case DataType::Int8: return "int8_t"; + case DataType::Int32: return "int32_t"; + } + throw std::runtime_error("Incorrect data type"); +} + +std::string ToString(Layout dl) +{ + switch(dl) + { + case Layout::Row: return "ck::tensor_layout::gemm::RowMajor"; + case Layout::Column: return "ck::tensor_layout::gemm::ColumnMajor"; + } + throw std::runtime_error("Incorrect layout"); +} + +std::string ToString(GemmType gt) +{ + switch(gt) + { + case GemmType::Default: return "ck::tensor_operation::device::GemmSpecialization::Default"; + } + throw std::runtime_error("Incorrect gemm type"); +} + +std::string SequenceStr(const std::vector& v) +{ + return "ck::Sequence<" + + JoinStrings(Transform(v, [](int x) { return std::to_string(x); }), ", ") + ">"; +} + +std::string MakeTuple(const std::vector& v) +{ + return "ck::Tuple<" + JoinStrings(v, ", ") + ">"; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/utils.cpp b/codegen/src/utils.cpp new file mode 100644 index 0000000000..cd6700c489 --- /dev/null +++ b/codegen/src/utils.cpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/host/utils.hpp" + +namespace ck { +namespace host { + +std::size_t integer_divide_ceil(std::size_t x, std::size_t y) +{ + return (x + y - std::size_t{1}) / y; +} + +const std::unordered_set& get_xdlop_archs() +{ + static std::unordered_set supported_archs{"gfx90a", "gfx908", "gfx940", "gfx942"}; + return supported_archs; +} + +} // namespace host +} // namespace ck diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt new file mode 100644 index 0000000000..897cce1c94 --- /dev/null +++ b/codegen/test/CMakeLists.txt @@ -0,0 +1,11 @@ + +list(APPEND CMAKE_PREFIX_PATH /opt/rocm) +add_subdirectory(rtc) + +file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) +foreach(TEST_SRC ${TEST_SRCS}) +get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE) +rocm_add_test_executable(test_host_${BASE_NAME} ${TEST_SRC}) +target_link_libraries(test_host_${BASE_NAME} ck_rtc ck_host) +target_include_directories(test_host_${BASE_NAME} PUBLIC include()) +endforeach() diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp new file mode 100644 index 0000000000..17b659993a --- /dev/null +++ b/codegen/test/gemm_multiple_d.cpp @@ -0,0 +1,185 @@ +#include "ck/host/device_gemm_multiple_d/problem.hpp" +#include "ck/host/device_gemm_multiple_d/operation.hpp" +#include "ck/host/headers.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include +#include +#include +#include +#include +#include +#include + +using half = _Float16; +// using half = __fp16; + +std::vector get_headers_for_test() +{ + std::vector result; + auto hs = ck::host::GetHeaders(); + std::transform( + hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { + return {p.first, p.second}; + }); + return result; +} + +template +rtc::buffer generate_buffer(std::size_t n, std::size_t seed = 0) +{ + rtc::buffer result(n); + std::mt19937 gen(seed); + std::uniform_real_distribution dis(-1.0); + std::generate(result.begin(), result.end(), [&] { return dis(gen); }); + return result; +} + +template +bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01) +{ + return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) { + return fabs(x - y) < atol + rtol * fabs(y); + }); +} + +std::string classify(double x) +{ + switch(std::fpclassify(x)) + { + case FP_INFINITE: return "inf"; + case FP_NAN: return "nan"; + case FP_NORMAL: return "normal"; + case FP_SUBNORMAL: return "subnormal"; + case FP_ZERO: return "zero"; + default: return "unknown"; + } +} + +template +void print_classification(const Buffer& x) +{ + std::unordered_set result; + for(const auto& i : x) + result.insert(classify(i)); + for(const auto& c : result) + std::cout << c << ", "; + std::cout << std::endl; +} + +template +void print_statistics(const Buffer& x) +{ + std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", "; + std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", "; + double num_elements = x.size(); + auto mean = + std::accumulate(x.begin(), x.end(), double{0.0}, std::plus{}) / num_elements; + auto stddev = std::sqrt( + std::accumulate(x.begin(), + x.end(), + double{0.0}, + [&](double r, double v) { return r + std::pow((v - mean), 2.0); }) / + num_elements); + std::cout << "Mean: " << mean << ", "; + std::cout << "StdDev: " << stddev << "\n"; +} + +template +void print_preview(const Buffer& x) +{ + if(x.size() <= 10) + { + std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; }); + } + else + { + std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; }); + std::cout << "..., "; + std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; }); + } + std::cout << std::endl; +} + +template +struct check_all +{ + rtc::buffer data{}; + bool operator()(const rtc::buffer& x) + { + if(data.empty()) + { + data = x; + return true; + } + if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); })) + return false; + return allclose(data, x); + } +}; + +template +auto report(const Solution& solution, bool pass) +{ + return test::make_predicate(solution.ToTemplateString(), [=] { return pass; }); +} + +const std::string gemm_compile_check = R"__ck__( +#include <${include}> + +extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_t* c) { + using G = ${template}; + constexpr auto desc = ${template}::make_descriptor(ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${k})), + ck::make_naive_tensor_descriptor(ck::make_tuple(${n}, ${k}), ck::make_tuple(1, ${n})), + ck::make_tuple(), + ck::make_naive_tensor_descriptor_packed(ck::make_tuple(${m}, ${n}))); + + static_assert(desc.IsValid(), "Invalid ck gemm."); + + if constexpr(desc.IsValid()) + { + ${template}::Run(desc, + a, + b, + ck::make_tuple(), + c); + } +} + +)__ck__"; + +TEST_CASE(test_problem_kernel) +{ + ck::host::device_gemm_multiple_d::Problem prob; + prob.M = 1024; + prob.N = 1024; + prob.K = 1024; + check_all check; + auto a = to_gpu(generate_buffer(1024 * 1024, 0)); + auto b = to_gpu(generate_buffer(1024 * 1024, 1)); + auto c = to_gpu(generate_buffer(1024 * 1024, 2)); + + for(auto solution : prob.GetSolutions("gfx90a")) + { + auto src = ck::host::InterpolateString(gemm_compile_check, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); + rtc::compile_options options; + options.kernel_name = "f"; + auto k = rtc::compile_kernel(srcs, options); + auto block_size = solution.GetTemplateParameter("BlockSize"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) * + ck::host::integer_divide_ceil(prob.N, n_per_block); + k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data()); + CHECK(report(solution, check(rtc::from_gpu(c)))); + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/include/test.hpp b/codegen/test/include/test.hpp new file mode 100644 index 0000000000..c3e38d6002 --- /dev/null +++ b/codegen/test/include/test.hpp @@ -0,0 +1,848 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include +#endif + +#ifndef MIGRAPHX_GUARD_TEST_TEST_HPP +#define MIGRAPHX_GUARD_TEST_TEST_HPP + +namespace test { +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_BINARY_OPERATORS(m) \ + m(==, equal) \ + m(!=, not_equal) \ + m(<=, less_than_equal) \ + m(>=, greater_than_equal) \ + m(<, less_than) \ + m(>, greater_than) \ + m(and, and_op) \ + m(or, or_op) +// clang-format on + +// clang-format off +// NOLINTNEXTLINE +#define TEST_FOREACH_UNARY_OPERATORS(m) \ + m(not, not_op) +// clang-format on + +// NOLINTNEXTLINE +#define TEST_EACH_BINARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x, U&& y) \ + { \ + return x op y; \ + } \ + }; + +// NOLINTNEXTLINE +#define TEST_EACH_UNARY_OPERATOR_OBJECT(op, name) \ + struct name \ + { \ + static std::string as_string() { return #op; } \ + template \ + static decltype(auto) call(T&& x) \ + { \ + return op x; \ + } \ + }; + +TEST_FOREACH_BINARY_OPERATORS(TEST_EACH_BINARY_OPERATOR_OBJECT) +TEST_FOREACH_UNARY_OPERATORS(TEST_EACH_UNARY_OPERATOR_OBJECT) + +struct nop +{ + static std::string as_string() { return ""; } + template + static auto call(T&& x) + { + return static_cast(x); + } +}; + +struct function +{ + static std::string as_string() { return ""; } + template + static decltype(auto) call(T&& x) + { + return x(); + } +}; + +template +Stream& stream_range(Stream& s, Iterator start, Iterator last); + +template +inline Stream& operator<<(Stream& s, std::nullptr_t) +{ + s << "nullptr"; + return s; +} + +template {}>::type> +inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.begin(), v.end())) +{ + s << "{ "; + stream_range(s, v.begin(), v.end()); + s << "}"; + return s; +} + +template +inline Stream& stream_range(Stream& s, Iterator start, Iterator last) +{ + if(start != last) + { + s << *start; + std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; }); + } + return s; +} + +template +const T& get_value(const T& x) +{ + return x; +} + +template +struct lhs_expression; + +template +lhs_expression make_lhs_expression(T&& lhs); + +template +lhs_expression make_lhs_expression(T&& lhs, Operator); + +// NOLINTNEXTLINE +#define TEST_EXPR_BINARY_OPERATOR(op, name) \ + template \ + auto operator op(const V& rhs2) const \ + { \ + return make_expression(*this, rhs2, name{}); /* NOLINT */ \ + } + +// NOLINTNEXTLINE +#define TEST_EXPR_UNARY_OPERATOR(op, name) \ + auto operator op() const { return make_lhs_expression(lhs, name{}); /* NOLINT */ } + +template +struct expression +{ + T lhs; + U rhs; + + friend std::ostream& operator<<(std::ostream& s, const expression& self) + { + s << self.lhs << " " << Operator::as_string() << " " << self.rhs; + return s; + } + + friend decltype(auto) get_value(const expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs), get_value(rhs)); }; + + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) +}; + +// TODO: Remove rvalue references +template +expression make_expression(T&& rhs, U&& lhs, Operator) +{ + return {std::forward(rhs), std::forward(lhs)}; +} + +// TODO: Remove rvalue reference +template +lhs_expression make_lhs_expression(T&& lhs) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +lhs_expression make_lhs_expression(T&& lhs, Operator) +{ + return lhs_expression{std::forward(lhs)}; +} + +template +struct lhs_expression +{ + T lhs; + explicit lhs_expression(T e) : lhs(e) {} + + friend std::ostream& operator<<(std::ostream& s, const lhs_expression& self) + { + std::string op = Operator::as_string(); + if(not op.empty()) + s << Operator::as_string() << " "; + s << self.lhs; + return s; + } + + friend decltype(auto) get_value(const lhs_expression& e) { return e.value(); } + + decltype(auto) value() const { return Operator::call(get_value(lhs)); } + + TEST_FOREACH_BINARY_OPERATORS(TEST_EXPR_BINARY_OPERATOR) + TEST_FOREACH_UNARY_OPERATORS(TEST_EXPR_UNARY_OPERATOR) + +// NOLINTNEXTLINE +#define TEST_LHS_REOPERATOR(op) \ + template \ + auto operator op(const U& rhs) const \ + { \ + return make_lhs_expression(lhs op rhs); \ + } + TEST_LHS_REOPERATOR(+) + TEST_LHS_REOPERATOR(-) + TEST_LHS_REOPERATOR(*) + TEST_LHS_REOPERATOR(/) + TEST_LHS_REOPERATOR(%) + TEST_LHS_REOPERATOR(&) + TEST_LHS_REOPERATOR(|) + TEST_LHS_REOPERATOR(^) +}; + +template +struct predicate +{ + std::string msg; + F f; + + friend std::ostream& operator<<(std::ostream& s, const predicate& self) + { + s << self.msg; + return s; + } + + decltype(auto) operator()() const { return f(); } + + operator decltype(auto)() const { return f(); } +}; + +template +auto make_predicate(const std::string& msg, F f) +{ + return make_lhs_expression(predicate{msg, f}, function{}); +} + +inline std::string as_string(bool x) +{ + if(x) + return "true"; + return "false"; +} + +template +std::string as_string(const T& x) +{ + std::stringstream ss; + ss << x; + return ss.str(); +} + +template +std::string as_string(Iterator start, Iterator last) +{ + std::stringstream ss; + stream_range(ss, start, last); + return ss.str(); +} + +template +auto make_function(const std::string& name, F f) +{ + return [=](auto&&... xs) { + std::vector args = {as_string(xs)...}; + return make_predicate(name + "(" + as_string(args.begin(), args.end()) + ")", + [=] { return f(xs...); }); + }; +} + +struct capture +{ + template + auto operator->*(const T& x) const + { + return make_lhs_expression(x); + } + + template + auto operator->*(const lhs_expression& x) const + { + return x; + } +}; + +enum class color +{ + reset = 0, + bold = 1, + underlined = 4, + fg_red = 31, + fg_green = 32, + fg_yellow = 33, + fg_blue = 34, + fg_default = 39, + bg_red = 41, + bg_green = 42, + bg_yellow = 43, + bg_blue = 44, + bg_default = 49 +}; +inline std::ostream& operator<<(std::ostream& os, const color& c) +{ +#ifndef _WIN32 + static const bool use_color = isatty(STDOUT_FILENO) != 0; + if(use_color) + return os << "\033[" << static_cast(c) << "m"; +#else + (void)c; +#endif + return os; +} + +inline std::atomic& failures() +{ + // NOLINTNEXTLINE + static std::atomic f = 0; + return f; +} + +template +void failed(T x, const char* msg, const char* func, const char* file, int line, F f) +{ + if(not bool(x.value())) + { + failures()++; + std::cout << func << std::endl; + std::cout << file << ":" << line << ":" << std::endl; + std::cout << color::bold << color::fg_red << " FAILED: " << color::reset << msg << " " + << "[ " << x << " ]" << std::endl; + f(); + } +} + +template +bool throws(F f) +{ + try + { + f(); + return false; + } + catch(...) + { + return true; + } +} + +template +bool throws(F f, const std::string& msg = "") +{ + try + { + f(); + return false; + } + catch(const Exception& ex) + { + return std::string(ex.what()).find(msg) != std::string::npos; + } +} + +template +auto within_abs(T px, U py, double ptol = 1e-6f) +{ + return make_function("near", [](auto x, auto y, auto tol) { return std::abs(x - y) < tol; })( + px, py, ptol); +} + +// This implements the basic globbing algorithm where `*` matches any number +// of characters(including none) and `?` matches any single character. It +// doesnt support character classes. +// +// This is a simple recursive implementation that scans the string where the +// string and pattern matches. When a `*` is found in the pattern, the +// `glob_match` function is called recursively to compare the rest of the +// pattern to the rest of the string. If the recursive call returns true, +// then we have a match. However, if it returns false, then we advance one +// character and call the recusrsive call again. This is referred to as a +// star-loop, which will consume zero or more characters. +// +// This simple recursive implementation works well for short string and +// patterns with few stars. First, it is unlikely to use many stars to glob +// test names. Secondly, using many stars is still signficantly faster than +// using the equivalent std::regex, which has a much slower time complexity. +template +bool glob_match(Iterator1 start, Iterator1 last, Iterator2 pattern_start, Iterator2 pattern_last) +{ + std::tie(start, pattern_start) = + std::mismatch(start, last, pattern_start, pattern_last, [](auto c, auto m) { + if(m == '?') + return true; + // We need a loop for star, so bail and handle the loop below + if(m == '*') + return false; + return c == m; + }); + // If there is no more pattern then return true if there is no more string to match + if(pattern_start == pattern_last) + return start == last; + // If the pattern is not a star then its a mismatch + if(*pattern_start != '*') + return false; + // Multiple stars are the same as a single star so skip over multiple stars + pattern_start = std::find_if(pattern_start, pattern_last, [](auto c) { return c != '*'; }); + // If the star is at the end then return true + if(pattern_start == pattern_last) + return true; + // star-loop: match the rest of the pattern and text + while(not glob_match(start, last, pattern_start, pattern_last) and start != last) + start++; + // If the string is empty then it means a match was never found + return start != last; +} + +using string_map = std::unordered_map>; + +template +string_map generic_parse(std::vector as, Keyword keyword) +{ + string_map result; + + std::string flag; + for(auto&& x : as) + { + auto f = keyword(x); + if(f.empty()) + { + result[flag].push_back(x); + } + else + { + flag = f.front(); + result[flag]; // Ensure the flag exists + flag = f.back(); + } + } + return result; +} + +using test_case = std::function; + +inline auto& get_test_cases() +{ + // NOLINTNEXTLINE + static std::vector> cases; + return cases; +} + +inline void add_test_case(std::string name, test_case f) +{ + get_test_cases().emplace_back(std::move(name), std::move(f)); +} + +struct auto_register_test_case +{ + template + auto_register_test_case(const char* name, F f) noexcept + { + add_test_case(name, f); + } +}; + +struct failure_error +{ +}; + +[[noreturn]] inline void fail() { throw failure_error{}; } + +struct driver +{ + driver() + { + add_flag({"--help", "-h"}, "Show help"); + add_flag({"--list", "-l"}, "List all test cases"); + add_flag({"--continue", "-c"}, "Continue after failure"); + add_flag({"--quiet", "-q"}, "Don't print out extra output"); + } + struct argument + { + std::vector flags = {}; + std::string help = ""; + int nargs = 1; + }; + + void add_arg(const std::vector& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 1}); + } + + void add_flag(const std::vector& flags, const std::string& help = "") + { + arguments.push_back(argument{flags, help, 0}); + } + + static void wrap(std::ostream& os, + const std::string& text, + const std::string& prefix = "", + unsigned int line_length = 80) + { + std::istringstream iss(text); + std::string line = prefix; + do + { + std::string word; + iss >> word; + if(line.length() + word.length() > line_length) + { + os << line << std::endl; + line = prefix; + } + line += word + " "; + } while(iss); + if(not line.empty()) + os << line << std::endl; + } + + void show_help(const std::string& exe) const + { + const std::string prefix = " "; + std::cout << std::endl; + std::cout << color::fg_yellow << "USAGE:" << color::reset << std::endl; + std::cout << " "; + std::cout << exe << " ... " << std::endl; + std::cout << std::endl; + + std::cout << color::fg_yellow << "ARGS:" << color::reset << std::endl; + std::cout << " "; + std::cout << color::fg_green << "..." << color::reset; + std::cout << std::endl; + + wrap(std::cout, + "Test cases to run. A test case can be either the exact test case name or a glob. A " + "glob expression uses a '*' to select zero or more characters or a '?' to select any " + "single character.", + prefix + prefix); + + std::cout << std::endl; + std::cout << color::fg_yellow << "OPTIONS:" << color::reset << std::endl; + for(auto&& arg : arguments) + { + std::cout << color::fg_green; + std::string arg_prefix = prefix; + for(const std::string& a : arg.flags) + { + std::cout << arg_prefix; + std::cout << a; + arg_prefix = ", "; + } + std::cout << color::reset << std::endl; + wrap(std::cout, arg.help, prefix + prefix); + } + } + + std::ostream& out() const + { + struct null_buffer : std::streambuf + { + virtual int overflow(int c) override { return c; } + }; + static null_buffer buffer; + static std::ostream null_stream(&buffer); + if(quiet) + return null_stream; + return std::cout; + } + + string_map parse(int argc, const char* argv[]) const + { + std::vector args(argv + 1, argv + argc); + string_map keys; + for(auto&& arg : arguments) + { + for(auto&& flag : arg.flags) + { + keys[flag] = {arg.flags.front()}; + if(arg.nargs == 0) + keys[flag].push_back(""); + } + } + auto result = generic_parse(args, [&](auto&& s) -> std::vector { + if(keys.count(s) > 0) + return keys[s]; + else + return {}; + }); + result["__exe__"].push_back(argv[0]); + return result; + } + + static std::string create_command(const string_map& args) + { + std::stringstream ss; + ss << args.at("__exe__").front(); + if(args.count("") > 0) + { + for(auto&& arg : args.at("")) + ss << " \"" << arg << "\""; + } + for(auto&& p : args) + { + if(p.first == "__exe__") + continue; + if(p.first.empty()) + continue; + ss << " " << p.first; + for(auto&& arg : p.second) + ss << " \"" << arg << "\""; + } + return ss.str(); + } + + static std::string fork(const std::string& name, string_map args) + { + std::string msg; + args[""] = {name}; + args.erase("--continue"); + args["--quiet"]; + auto cmd = create_command(args); + auto r = std::system(cmd.c_str()); // NOLINT + if(r != 0) + msg = "Exited with " + std::to_string(r); + return msg; + } + + static std::vector> glob_tests(const std::string& pattern) + { + std::vector> result; + std::copy_if(get_test_cases().begin(), + get_test_cases().end(), + std::back_inserter(result), + [&](auto&& p) { + return glob_match( + p.first.begin(), p.first.end(), pattern.begin(), pattern.end()); + }); + return result; + } + + void run_test_case(const std::string& name, const test_case& f, const string_map& args) + { + ran++; + out() << color::fg_green << "[ RUN ] " << color::reset << color::bold << name + << color::reset << std::endl; + std::string msg; + auto start = std::chrono::steady_clock::now(); + if(args.count("--continue") > 0) + { + msg = fork(name, args); + } + else + { + try + { + failures() = 0; + f(); + } + // cppcheck-suppress migraphx-EmptyCatchStatement + catch(const failure_error&) + { + } + } + auto finish = std::chrono::steady_clock::now(); + auto elapsed_ms = + std::chrono::duration_cast>(finish - start) + .count(); + if(msg.empty() and failures() != 0) + { + if(failures() == 1) + msg = "Test failure"; + else + msg = std::to_string(failures()) + " test failures"; + } + if(msg.empty()) + { + out() << color::fg_green << "[ COMPLETE ] " << color::reset; + } + else + { + failed.push_back(name); + out() << color::fg_red << "[ FAILED ] " << color::reset; + } + out() << color::bold << name << color::reset; + out() << color::fg_blue << " (" << elapsed_ms << "ms)" << color::reset; + if(not msg.empty()) + out() << ": " << color::fg_yellow << msg << color::reset; + out() << std::endl; + } + + void run(int argc, const char* argv[]) + { + auto args = parse(argc, argv); + if(args.count("--help") > 0) + { + show_help(args.at("__exe__").front()); + return; + } + if(args.count("--list") > 0) + { + for(auto&& tc : get_test_cases()) + out() << tc.first << std::endl; + return; + } + + if(args.count("--quiet") > 0) + quiet = true; + + auto cases = args[""]; + if(cases.empty()) + { + for(auto&& tc : get_test_cases()) + run_test_case(tc.first, tc.second, args); + } + else + { + std::unordered_map m(get_test_cases().begin(), + get_test_cases().end()); + + for(auto&& iname : cases) + { + std::vector> found_cases; + for(auto&& pattern : get_case_names(iname)) + { + auto f = m.find(pattern); + if(f == m.end()) + { + found_cases = glob_tests(pattern); + } + else + { + found_cases.push_back(*f); + } + } + if(found_cases.empty()) + { + out() << color::fg_red << "[ ERROR ] Test case '" << iname << "' not found." + << color::reset << std::endl; + failed.push_back(iname); + } + for(auto&& p : found_cases) + run_test_case(p.first, p.second, args); + } + } + out() << color::fg_green << "[==========] " << color::fg_yellow << ran << " tests ran" + << color::reset << std::endl; + if(not failed.empty()) + { + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << failed.size() + << " tests failed" << color::reset << std::endl; + for(auto&& name : failed) + out() << color::fg_red << "[ FAILED ] " << color::fg_yellow << name + << color::reset << std::endl; + std::exit(1); + } + } + + std::function(const std::string&)> get_case_names = + [](const std::string& name) -> std::vector { return {name}; }; + std::vector arguments = {}; + std::vector failed = {}; + std::size_t ran = 0; + bool quiet = false; +}; + +inline void run(int argc, const char* argv[]) +{ + driver d{}; + d.run(argc, argv); +} + +} // namespace test + +// NOLINTNEXTLINE +#define TEST_CAPTURE(...) test::capture{}->*__VA_ARGS__ + +// NOLINTNEXTLINE +#define CHECK(...) \ + test::failed( \ + TEST_CAPTURE(__VA_ARGS__), #__VA_ARGS__, __PRETTY_FUNCTION__, __FILE__, __LINE__, [] {}) + +// NOLINTNEXTLINE +#define EXPECT(...) \ + test::failed(TEST_CAPTURE(__VA_ARGS__), \ + #__VA_ARGS__, \ + __PRETTY_FUNCTION__, \ + __FILE__, \ + __LINE__, \ + &test::fail) +// NOLINTNEXTLINE +#define STATUS(...) EXPECT((__VA_ARGS__) == 0) + +// NOLINTNEXTLINE +#define TEST_CAT(x, ...) TEST_PRIMITIVE_CAT(x, __VA_ARGS__) +// NOLINTNEXTLINE +#define TEST_PRIMITIVE_CAT(x, ...) x##__VA_ARGS__ + +// NOLINTNEXTLINE +#define TEST_CASE_REGISTER(...) \ + static test::auto_register_test_case TEST_CAT(register_test_case_, __LINE__) = \ + test::auto_register_test_case(#__VA_ARGS__, &__VA_ARGS__); + +// NOLINTNEXTLINE +#define TEST_CASE(...) \ + void __VA_ARGS__(); \ + TEST_CASE_REGISTER(__VA_ARGS__) \ + void __VA_ARGS__() + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wglobal-constructors" +#endif + +#endif diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt new file mode 100644 index 0000000000..441e60ca88 --- /dev/null +++ b/codegen/test/rtc/CMakeLists.txt @@ -0,0 +1,6 @@ + +find_package(hip) +file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) +add_library(ck_rtc ${RTC_SOURCES}) +target_include_directories(ck_rtc PUBLIC include) +target_link_libraries(ck_rtc PUBLIC hip::host) diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp new file mode 100644 index 0000000000..5a4a4b0dd6 --- /dev/null +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -0,0 +1,27 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_COMPILE_KERNEL + +#include +#include +#include + +namespace rtc { + +struct src_file +{ + std::filesystem::path path; + std::string_view content; +}; + +struct compile_options +{ + std::string flags = ""; + std::string kernel_name = "main"; +}; + +kernel compile_kernel(const std::vector& src, + compile_options options = compile_options{}); + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/hip.hpp b/codegen/test/rtc/include/rtc/hip.hpp new file mode 100644 index 0000000000..6b523382dc --- /dev/null +++ b/codegen/test/rtc/include/rtc/hip.hpp @@ -0,0 +1,78 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_HIP + +#include +#include +#include + +namespace rtc { + +template +struct buffer +{ + buffer() : ptr(), n(0) {} + buffer(std::shared_ptr p, std::size_t sz) : ptr(p), n(sz) {} + buffer(std::shared_ptr p, std::size_t sz) + : ptr(std::reinterpret_pointer_cast(p)), n(sz) + { + } + explicit buffer(std::size_t sz) : ptr(new T[sz]), n(sz) {} + T* begin() { return data(); } + T* end() { return data() + size(); } + const T* begin() const { return data(); } + const T* end() const { return data() + size(); } + + T& front() { return data()[0]; } + T& back() { return data()[size() - 1]; } + T& operator[](std::size_t i) { return data()[i]; } + T& at(std::size_t i) + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + + const T& front() const { return data()[0]; } + const T& back() const { return data()[size() - 1]; } + const T& operator[](std::size_t i) const { return data()[i]; } + const T& at(std::size_t i) const + { + if(i >= size()) + throw std::runtime_error("Out of bounds"); + return data()[i]; + } + const T* data() const { return ptr.get(); } + T* data() { return ptr.get(); } + + std::size_t size() const { return n; } + std::size_t bytes() const { return size() * sizeof(T); } + + bool empty() const { return size() == 0; } + + private: + std::shared_ptr ptr; + std::size_t n; +}; + +std::string get_device_name(); +std::string hip_error(int error); + +std::shared_ptr allocate_gpu(std::size_t sz, bool host = false); +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host = false); +std::shared_ptr read_from_gpu(const void* x, std::size_t sz); + +template +buffer to_gpu(const buffer& input) +{ + return {write_to_gpu(input.data(), input.bytes()), input.size()}; +} + +template +buffer from_gpu(const buffer& input) +{ + return {read_from_gpu(input.data(), input.bytes()), input.size()}; +} + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp new file mode 100644 index 0000000000..9f38e90416 --- /dev/null +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -0,0 +1,62 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_KERNEL + +#include +#include +#include +#include + +namespace rtc { + +struct kernel_argument +{ + template , + class = std::enable_if_t{}>> + kernel_argument(T&& x) : size(sizeof(U)), align(alignof(U)), data(&x) // NOLINT + { + } + std::size_t size; + std::size_t align; + void* data; +}; + +std::vector pack_args(const std::vector& args); + +struct kernel_impl; + +struct kernel +{ + kernel() = default; + kernel(const char* image, const std::string& name); + template + kernel(const std::vector& image, const std::string& name) + : kernel(reinterpret_cast(image.data()), name) + { + static_assert(sizeof(T) == 1, "Only byte types"); + } + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const; + + void launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const; + + template + auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const + { + return [=](auto&&... xs) { + launch(stream, global, local, std::vector{xs...}, zs...); + }; + } + + private: + std::shared_ptr impl; +}; +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/manage_ptr.hpp b/codegen/test/rtc/include/rtc/manage_ptr.hpp new file mode 100644 index 0000000000..92edf12628 --- /dev/null +++ b/codegen/test/rtc/include/rtc/manage_ptr.hpp @@ -0,0 +1,55 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_MANAGE_POINTER + +#include +#include + +namespace rtc { +template +struct manage_deleter +{ + template + void operator()(T* x) const + { + if(x != nullptr) + { + (void)f(x); + } + } +}; + +struct null_deleter +{ + template + void operator()(T*) const + { + } +}; + +template +using manage_ptr = std::unique_ptr>; + +template +struct element_type +{ + using type = typename T::element_type; +}; + +template +using remove_ptr = typename std:: + conditional_t{}, std::remove_pointer, element_type>::type; + +template +using shared = std::shared_ptr>; + +template +shared share(T p) +{ + return shared{std::move(p)}; +} + +#define RTC_MANAGE_PTR(T, F) rtc::manage_ptr, decltype(&F), &F> + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp new file mode 100644 index 0000000000..f0fd1f72bb --- /dev/null +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -0,0 +1,24 @@ +#ifndef GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR +#define GUARD_HOST_TEST_RTC_INCLUDE_RTC_TMP_DIR + +#include +#include + +namespace rtc { + +struct tmp_dir +{ + std::filesystem::path path; + tmp_dir(const std::string& prefix = ""); + + void execute(const std::string& cmd) const; + + tmp_dir(tmp_dir const&) = delete; + tmp_dir& operator=(tmp_dir const&) = delete; + + ~tmp_dir(); +}; + +} // namespace rtc + +#endif diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp new file mode 100644 index 0000000000..7ea55b9328 --- /dev/null +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -0,0 +1,95 @@ +#include "rtc/hip.hpp" +#include +#include +#include +#include +#include +#include + +namespace rtc { + +template +T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + std::ifstream is(filename, std::ios::binary | std::ios::ate); + if(nbytes == 0) + { + // if there is a non-zero offset and nbytes is not set, + // calculate size of remaining bytes to read + nbytes = is.tellg(); + if(offset > nbytes) + throw std::runtime_error("offset is larger than file size"); + nbytes -= offset; + } + if(nbytes < 1) + throw std::runtime_error("Invalid size for: " + filename); + is.seekg(offset, std::ios::beg); + + T buffer(nbytes, 0); + if(not is.read(&buffer[0], nbytes)) + throw std::runtime_error("Error reading file: " + filename); + return buffer; +} + +std::vector read_buffer(const std::string& filename, size_t offset = 0, size_t nbytes = 0) +{ + return generic_read_file>(filename, offset, nbytes); +} + +std::string read_string(const std::string& filename) +{ + return generic_read_file(filename); +} + +void write_buffer(const std::string& filename, const char* buffer, std::size_t size) +{ + std::ofstream os(filename); + os.write(buffer, size); +} +void write_buffer(const std::string& filename, const std::vector& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} +void write_string(const std::string& filename, const std::string_view& buffer) +{ + write_buffer(filename, buffer.data(), buffer.size()); +} + +std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device-only"; } + +kernel compile_kernel(const std::vector& srcs, compile_options options) +{ + assert(not srcs.empty()); + tmp_dir td{"compile"}; + options.flags += " -I. -O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + std::string out; + + for(const auto& src : srcs) + { + std::filesystem::path full_path = td.path / src.path; + std::filesystem::path parent_path = full_path.parent_path(); + std::filesystem::create_directories(parent_path); + write_string(full_path.string(), src.content); + if(src.path.extension().string() == ".cpp") + { + options.flags += " -c " + src.path.filename().string(); + if(out.empty()) + out = src.path.stem().string() + ".o"; + } + } + + options.flags += " -o " + out; + td.execute(compiler() + options.flags); + + auto out_path = td.path / out; + if(not std::filesystem::exists(out_path)) + throw std::runtime_error("Output file missing: " + out); + + auto obj = read_buffer(out_path.string()); + + return kernel{obj.data(), options.kernel_name}; +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/hip.cpp b/codegen/test/rtc/src/hip.cpp new file mode 100644 index 0000000000..10e38c9adb --- /dev/null +++ b/codegen/test/rtc/src/hip.cpp @@ -0,0 +1,102 @@ +#include +#include +#include +#include + +namespace rtc { + +using hip_ptr = RTC_MANAGE_PTR(void, hipFree); + +std::string hip_error(int error) { return hipGetErrorString(static_cast(error)); } + +int get_device_id() +{ + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + throw std::runtime_error("No device"); + return device; +} + +std::string get_device_name() +{ + hipDeviceProp_t props{}; + auto status = hipGetDeviceProperties(&props, get_device_id()); + if(status != hipSuccess) + throw std::runtime_error("Failed to get device properties"); + return props.gcnArchName; +} + +bool is_device_ptr(const void* ptr) +{ + hipPointerAttribute_t attr; + auto status = hipPointerGetAttributes(&attr, ptr); + if(status != hipSuccess) + return false; + return attr.type == hipMemoryTypeDevice; +} + +void gpu_sync() +{ + auto status = hipDeviceSynchronize(); + if(status != hipSuccess) + throw std::runtime_error("hip device synchronization failed: " + hip_error(status)); +} + +std::size_t get_available_gpu_memory() +{ + size_t free; + size_t total; + auto status = hipMemGetInfo(&free, &total); + if(status != hipSuccess) + throw std::runtime_error("Failed getting available memory: " + hip_error(status)); + return free; +} + +std::shared_ptr allocate_gpu(std::size_t sz, bool host) +{ + if(sz > get_available_gpu_memory()) + throw std::runtime_error("Memory not available to allocate buffer: " + std::to_string(sz)); + void* alloc_ptr = nullptr; + auto status = host ? hipHostMalloc(&alloc_ptr, sz) : hipMalloc(&alloc_ptr, sz); + if(status != hipSuccess) + { + if(host) + throw std::runtime_error("Gpu allocation failed: " + hip_error(status)); + else + return allocate_gpu(sz, true); + } + assert(alloc_ptr != nullptr); + std::shared_ptr result = share(hip_ptr{alloc_ptr}); + return result; +} + +std::shared_ptr write_to_gpu(const void* x, std::size_t sz, bool host) +{ + gpu_sync(); + auto result = allocate_gpu(sz, host); + assert(is_device_ptr(result.get())); + assert(not is_device_ptr(x)); + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyHostToDevice); + if(status != hipSuccess) + throw std::runtime_error("Copy to gpu failed: " + hip_error(status)); + return result; +} + +std::shared_ptr read_from_gpu(const void* x, std::size_t sz) +{ + gpu_sync(); + std::shared_ptr result(new char[sz]); + assert(not is_device_ptr(result.get())); + if(not is_device_ptr(x)) + { + throw std::runtime_error( + "read_from_gpu() requires Src buffer to be on the GPU, Copy from gpu failed\n"); + } + auto status = hipMemcpy(result.get(), x, sz, hipMemcpyDeviceToHost); + if(status != hipSuccess) + throw std::runtime_error("Copy from gpu failed: " + hip_error(status)); // NOLINT + return std::static_pointer_cast(result); +} + +} // namespace rtc diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp new file mode 100644 index 0000000000..f4fb19130c --- /dev/null +++ b/codegen/test/rtc/src/kernel.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include + +// extern declare the function since hip/hip_ext.h header is broken +extern hipError_t hipExtModuleLaunchKernel(hipFunction_t, // NOLINT + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + size_t, + hipStream_t, + void**, + void**, + hipEvent_t = nullptr, + hipEvent_t = nullptr, + uint32_t = 0); + +namespace rtc { + +std::vector pack_args(const std::vector& args) +{ + std::vector kernargs; + for(auto&& arg : args) + { + std::size_t n = arg.size; + const auto* p = static_cast(arg.data); + // Insert padding + std::size_t padding = (arg.align - (kernargs.size() % arg.align)) % arg.align; + kernargs.insert(kernargs.end(), padding, 0); + kernargs.insert(kernargs.end(), p, p + n); + } + return kernargs; +} + +using hip_module_ptr = RTC_MANAGE_PTR(hipModule_t, hipModuleUnload); + +struct kernel_impl +{ + hip_module_ptr module = nullptr; + hipFunction_t fun = nullptr; +}; + +hip_module_ptr load_module(const char* image) +{ + hipModule_t raw_m; + auto status = hipModuleLoadData(&raw_m, image); + hip_module_ptr m{raw_m}; + if(status != hipSuccess) + throw std::runtime_error("Failed to load module: " + hip_error(status)); + return m; +} + +kernel::kernel(const char* image, const std::string& name) : impl(std::make_shared()) +{ + impl->module = load_module(image); + auto status = hipModuleGetFunction(&impl->fun, impl->module.get(), name.c_str()); + if(hipSuccess != status) + throw std::runtime_error("Failed to get function: " + name + ": " + hip_error(status)); +} + +void launch_kernel(hipFunction_t fun, + hipStream_t stream, + std::size_t global, + std::size_t local, + void* kernargs, + std::size_t size) +{ + assert(global > 0); + assert(local > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + global, + 1, + 1, + local, + 1, + 1, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + std::vector args) const +{ + assert(impl != nullptr); + void* kernargs = args.data(); + std::size_t size = args.size() * sizeof(void*); + + launch_kernel(impl->fun, stream, global, local, kernargs, size); +} + +void kernel::launch(hipStream_t stream, + std::size_t global, + std::size_t local, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); +} + +} // namespace rtc \ No newline at end of file diff --git a/codegen/test/rtc/src/tmp_dir.cpp b/codegen/test/rtc/src/tmp_dir.cpp new file mode 100644 index 0000000000..3b0f0170e8 --- /dev/null +++ b/codegen/test/rtc/src/tmp_dir.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include + +namespace rtc { +std::string random_string(std::string::size_type length) +{ + static const std::string& chars = "0123456789" + "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::mt19937 rg{std::random_device{}()}; + std::uniform_int_distribution pick(0, chars.length() - 1); + + std::string str(length, 0); + std::generate(str.begin(), str.end(), [&] { return chars[pick(rg)]; }); + + return str; +} + +std::string unique_string(const std::string& prefix) +{ + auto pid = getpid(); + auto tid = std::this_thread::get_id(); + auto clk = std::chrono::steady_clock::now().time_since_epoch().count(); + std::stringstream ss; + ss << std::hex << prefix << "-" << pid << "-" << tid << "-" << clk << "-" << random_string(16); + return ss.str(); +} + +tmp_dir::tmp_dir(const std::string& prefix) + : path(std::filesystem::temp_directory_path() / + unique_string(prefix.empty() ? "ck-rtc" : "ck-rtc-" + prefix)) +{ + std::filesystem::create_directories(this->path); +} + +void tmp_dir::execute(const std::string& cmd) const +{ + std::string s = "cd " + path.string() + "; " + cmd; + std::system(s.c_str()); +} + +tmp_dir::~tmp_dir() { std::filesystem::remove_all(this->path); } + +} // namespace rtc \ No newline at end of file diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 42f8daef71..77ed9625c5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -498,6 +498,86 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + return true; + } + static bool IsSupportedArgument(const Argument& arg) { if(!ck::is_xdl_supported()) @@ -505,87 +585,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD && ABlockTransferSrcVectorDim == 2) - { - if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) - { - // FIXME: not rigorous - if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector laod of B - if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) - { - if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) - { - // FIXME: not rigorous - if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) - { - return false; - } - } - else - { - return false; - } - - // check vector load of Ds - // only support RowMajor for now - bool all_valid = true; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - - if constexpr(!is_same_v) - { - all_valid = false; - } - }); - - if(!all_valid) - { - return false; - } - - // check vector store of E - // only support RowMajor for now - if constexpr(is_same_v) - { - if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) - { - return false; - } - } - else - { - return false; - } - } - - return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and + GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, arg.b_grid_desc_n_k_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, @@ -708,6 +709,178 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD + struct Descriptor + { + static constexpr auto ds_tuple() + { + return transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + DsDesc{}); + } + using AGridDesc_M_K = + remove_cvref_t; + using BGridDesc_N_K = + remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t; + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_tuple()))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; + using Block2ETileMap = remove_cvref_t; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k; + BGridDesc_N_K b_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + EGridDesc_M_N e_grid_desc_m_n; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; + + // for checking vector load/store + index_t MRaw; + index_t NRaw; + index_t KRaw; + + bool has_main_k_block_loop = true; + + constexpr Descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_) + : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)}, + b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)}, + ds_grid_desc_m_n{transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + ds)}, + e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)}, + a_grid_desc_ak0_m_ak1{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)}, + b_grid_desc_bk0_n_bk1{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + transform_tuples( + [&](auto d) constexpr { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); + }, + ds))}, + e_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)}, + block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, + MRaw{e.GetLength(I0)}, + NRaw{e.GetLength(I1)}, + KRaw{a.GetLength(I1)} + { + } + + constexpr bool IsValid() const + { + return GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw); + } + + constexpr index_t GetBlockSize() const { return BlockSize; } + + constexpr index_t GetGridSize() const + { + return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{}) + { + return Descriptor( + a, b, ds, e, a_element_op, b_element_op, cde_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid) + { + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + assert(desc.IsValid()); + if(desc.has_main_k_block_loop) + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + else + { + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + } }; } // namespace device diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 6266fb40f0..a89e14cbdb 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - __host__ __device__ BlockToCTileMap_M00_N0_M01() = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 1) + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) { } @@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01 } template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const { if constexpr(DeviceCTileIndexCheck) return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); @@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -120,18 +120,19 @@ struct BlockToCTileMap_M00_N0_M01Adapt static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const BlockToCTileMap_M00_N0_M01Adapt&) = - default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(BlockToCTileMap_M00_N0_M01Adapt&&) = - default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + BlockToCTileMap_M00_N0_M01Adapt&&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt& + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) : M_(M), N_(N), M01_(M01) { #if 0 @@ -142,8 +143,9 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, - index_t M01 = 8) + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) : BlockToCTileMap_M00_N0_M01Adapt( c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) { @@ -164,7 +166,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } @@ -237,8 +239,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt } template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const { return true; // always valid provided that user gets grid size from CalculateGridSize() } @@ -616,7 +618,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt return true; // always valid provided that user gets grid size from CalculateGridSize() } - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } private: index_t M01_; @@ -674,7 +679,7 @@ struct BlockToCTileMap_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -786,7 +791,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01 return true; } - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { if constexpr(DeviceCTileIndexCheck) return true; // validity check moved to kernel @@ -910,7 +915,7 @@ struct OffsettedBlockToCTileMap } template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); } @@ -967,7 +972,7 @@ struct BlockToCTileMap_3DGrid_KSplit } template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const { return true; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 15c30a0dad..c0a3d29f85 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -264,7 +264,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const BGridDesc_N_K& b_grid_desc_n_k, const DsGridDesc_M_N& ds_grid_desc_m_n, const EGridDesc_M_N& e_grid_desc_m_n, - const Block2ETileMap& block_2_etile_map) + const Block2ETileMap&) { static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, @@ -310,10 +310,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } // check block-to-E-tile - if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) - { - return false; - } + // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + //{ + // return false; + //} // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) // check tensor size: cannot be larger than 2GB each From 1ddc8a841a5f4fbec9656e1a6b3805853211fddc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 5 Mar 2024 21:55:01 -0800 Subject: [PATCH 19/36] Bump rocm-docs-core from 0.35.0 to 0.35.1 in /docs/sphinx (#1187) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.35.0 to 0.35.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.35.0...v0.35.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 1576e54537..93c15a2160 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.35.0 +rocm-docs-core==0.35.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index a8cb087225..8faeac85db 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.35.0 +rocm-docs-core==0.35.1 # via -r requirements.in six==1.16.0 # via From adb3615d1a8db9bfc1a53987443d48df20f87cc7 Mon Sep 17 00:00:00 2001 From: yhuiYH <145490163+yhuiYH@users.noreply.github.com> Date: Thu, 7 Mar 2024 13:08:37 -0500 Subject: [PATCH 20/36] Update CODEOWNERS to use documentation group (#1190) Also had to remove a name --- .github/CODEOWNERS | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e4d0d47a2e..37407cebf1 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,7 +1,7 @@ -* @zjing14 @asroy @junliume @illsilin @carlushuang @aosewski +* @zjing14 @junliume @illsilin @carlushuang @aosewski # Documentation files -docs/* @saadrahim @LisaDelaney -*.md @saadrahim @LisaDelaney -*.rst @saadrahim @LisaDelaney -# Header directory -library/include/* @saadrahim @LisaDelaney +docs/* @ROCm/rocm-documentation +*.md @ROCm/rocm-documentation +*.rst @ROCm/rocm-documentation +# Header directory for Doxygen documentation +library/include/* @ROCm/rocm-documentation From 0e28de9766f29bd8687be6dd9e1ed8894869ee55 Mon Sep 17 00:00:00 2001 From: Lisa Date: Thu, 7 Mar 2024 11:09:17 -0700 Subject: [PATCH 21/36] Update link (#1186) --- docs/dockerhub.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/dockerhub.rst b/docs/dockerhub.rst index 21121f1b82..87eb5a4f81 100644 --- a/docs/dockerhub.rst +++ b/docs/dockerhub.rst @@ -36,7 +36,7 @@ What is inside the image? The docker images have everything you need for running CK including: -* `ROCm `_ +* `ROCm `_ * `CMake `_ * `Compiler `_ * `Composable Kernel library `_ From 363feb482d03fa217adce6a56f5e74a2ae4a00c3 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:05:05 -0600 Subject: [PATCH 22/36] Refactor tolerances for correctness check in gemm op (#1188) * Refactor tolerances for correctness check * Update tolerances * Update host-side gemm * Update reference gemm call --- example/01_gemm/gemm_xdl_fp16_fp8.cpp | 10 ++- example/01_gemm/run_gemm_example.inc | 89 ++++++++++++++++++++++++++- 2 files changed, 95 insertions(+), 4 deletions(-) diff --git a/example/01_gemm/gemm_xdl_fp16_fp8.cpp b/example/01_gemm/gemm_xdl_fp16_fp8.cpp index d3cf3d397a..979a200791 100644 --- a/example/01_gemm/gemm_xdl_fp16_fp8.cpp +++ b/example/01_gemm/gemm_xdl_fp16_fp8.cpp @@ -33,8 +33,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; #include "run_gemm_example.inc" diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 49743a9c43..2837937ead 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -5,6 +5,88 @@ #include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -240,8 +322,11 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) #else c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - return ck::utils::check_err( - c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 1e-1, 1e-1); + return ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); #endif } From 1837040a9c3253ba51506b7b7d236836b6efaa8d Mon Sep 17 00:00:00 2001 From: zjing14 Date: Fri, 8 Mar 2024 19:11:51 -0600 Subject: [PATCH 23/36] Navi3 rel (#1176) * wmma_op + unit test * add arch limitation to wmma test * change arch limitation * Refactor + Add all type unit test(int4 compile failed) * Add f32_16x16x16_bf16 unit test * tempsave * tempsave * tempsave * runtime bug, cannot find symbol * workaround for incorrect HIP warpSize return value * debugging * tempsave * Correctness OK, waiting for optimization * Tidy up + format * temp save * temp save, reproduce the v_bfi_b32 issue * add inline asm for wmmaop test * tidy up * clean some debug purpose code * discard some codes * clang format * clang format * compiler issue fixed + increase tile size * navi3x_multipleD+example * temp save * workable * batchedgemm[OK], groupconv[debug] * groupconv: Sanity check[OK], Performance[Bad] * navi3x_groupconv_need_optimization * create necessary files * save progress * Add Inter-Row thread transfer * save progress * save debugging progress * sanity check pass * fix a host tensor bug and clean up flash-attn code * format * cancel unnecessary change * cancel unnecessary change * cancel unnecessary change * temp save, add asm backend flag to amd_wmma * Mat-A LDS Bypass sanity pass * temp save * gemm sanity fix * Porting new blockwise gemm to flash attention * Example branch provide to compiler team * tempsave * Fix a bug * batched gemm ported * conv A-skip lds ported * Skip B-Lds real gemm * Skip B Lds Gemm + MulD * batched gemm, conv, skip b lds * format * Attn, skip b lds * Change GridwiseOp nam * fix a typo caused bug * Skip A_Lds sanity pass, Skip B_Lds scratch occured * Bug found, intra-row permute off caused * bug found * a fix * disable buffer load due to incorrect 3rd dword * update fmha config, no scratch generated * update 3rd dword * fmha config update * FMHA, add support to gfx1101/gfx1102 * Merge origin dev (#2) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer Co-authored-by: zjing14 * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commit bb5530af91352dca062b791313d9b77700335ae9. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root * fixed quant example (#672) Co-authored-by: root * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * add vector load check * solve conflicts --------- Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: zjing14 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> * Disable SkipLDS & Align AIT api (#3) * fix layernorm, reduction Ops (#4) * [Navi3x] Fix Gridwise_multiple_d operation (#649) * Add CMake Option "USE_OPT_NAVI3X" * fix bug * standardize docs (#655) * Separate bibtex requirement from rocm-docs-core (#656) * separate bibtex requirement from rocm-docs-core * point requirements to source rocm-docs-core repo * Add CMake Option "USE_OPT_NAVI3X" (#647) * Add CMake Option "USE_OPT_NAVI3X" * remove navi3x opt compile option from cmake script * Conv + quantization + tanh (#645) * Rename file. Prepare to support another activation * Add comment for quantization * Extract out_elementop * Add tanh example * Add conv + bias + tanh quantization instance * Add missing parameter * Refine cmake * Add external api and client example * Extract variable in example * Fix the comment --------- Co-authored-by: zjing14 * Add a denorm test fix (#603) * Add type_convert implementations for bf16 * Add the fix for conv_fwd * Add the fix for conv_bwd_data * Add the fix for conv_bwd_weight * Format * Format * Another format * Add a macro to use workaround on MI200 only * Format --------- Co-authored-by: Rosty Geyyer Co-authored-by: zjing14 * simplify karg in device/grid of split-k op (#644) * simplify karg in device/grid split-k op * fix mk_kn_mn instances * add more instances * use name from tensor layout * fix 3rd dword of buffer source descriptor (#659) * add fp64 instances (#658) Co-authored-by: root * Issue #666: Revert "simplify karg in device/grid of split-k op (#644)" (#665) This reverts commit bb5530af91352dca062b791313d9b77700335ae9. * Groupnorm + swish external api (#668) * Rename to proper naming * Add example of groupnorm + swish * Extract duplicate code in example * Add groupnorm + swish instances * Ractor instance generation, split into multiple cpp file * Add external api and client example * Refine profiler message * Use ck math version of exp * Refine problem size in example * Add host version of exp * add a marco to turn on/off denorm fix (off by default) (#673) * add a marco to turn off denorm fix by default * expose the marco --------- Co-authored-by: root * fixed quant example (#672) Co-authored-by: root * Add dependabot config and pin rocm-docs-core (#663) * [gtest] suppress unsafe buffer warn (#670) ref: https://github.com/ROCmSoftwarePlatform/MIOpen/pull/1912 * Add memory index guard in wmma device ops (#667) * Add more macros to turn on/off denorm fix (#678) Co-authored-by: Rosty Geyyer * Fix a typo (#676) * Add (#677) * Allow using ROCm release candidate compilers. (#679) * enable use of rocm5.5 release candidate 4 * upgrade to ROCM5.5 RC5 * try fix the PUB_KEY error, remove the cmake-data package * upgrade to latest cmake version * use private dockerhub repo for rocm5.5 rc5 * add missing bracket * Disable SkipLDS & Align AIT api * Update dependabot config (#682) Co-authored-by: samjwu * update attn api * solve type_convert bug + enable --------- Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: zjing14 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu Co-authored-by: haocwang * fix typo * Fix attention with causal mask * multiple fix, try ait compile * Add A/B not use LDS pipeline * Clang format, Add gfx1101, gfx1102 support of FMHA example * cancel change of format script * 1. Enable 2-stage global Prefetch ( May cause VGPR spilling) 2. Enable FP16 accumulator blockwise_gemm * clang-format * 1. change blockwise gemm loopover direction from kmn to mnk ( ~1% improvement) 2. change kernel timing mode to 50 warmup + 50 timed repeat * Update low level abstration of blockwise gemm wmma * (2/5) bilinear gemm pass, perf bug: skip a lds has lower performance than skip b lds * (3/5) batched gemm pass, perf bug: skip a lds has lower performance than skip b lds * (4/5) grouped conv pass * (5/5) attention pass, todo: debug lds perf bug * AIT Attention API refactor (#8) * sanity pass * sanity pass 2 * confirm significant performance regression. * turn on all instances * turn off instance format * Fix bug & tunning & format * DML meta, self_attn+cross_attn * sanity pass * remove useless flag * update tile and problem size used in AIT attention * bug fix in grouped conv supporting check * deprecate inline asm wmma * Bug fix: double lds skip * clang-format * Fix errors in 1. example, fmha 2. gridwise pipeline 3. deviceop, fmha, change some containers from vector to array * part2 of previous commit * clang format * API fix of gridwisegemmpipeline * separate array base and vector base attention tensor transformation * fix gemm * clang format * add gemm fp16 instances * Temp save * fpAintB kernel compile pass * Sanity pass. * Temp save * debug code enabled * Fp16AInt8B_GEMM sanity * MQA implementation * GQA-4 example * tempsave * Compile pass * New implementation of fp16Aint8B Gemm, Acheieve similar math throughput with native fp16 Gemm * format * Todo: fix gemm_bilinear_wmma instances compilation bug * Solve a bug when K1=16 * remove unnecessary changes * Remove tensor layout limitation to LDS usage in tesnor contraction * update self-attention and cross-attention * fix a typo of name * Add arch limiter for fp8 gemm * enable fp8 gemm_xdl for all gfx9 targets * temporarily disable gemm_xdl_fp16_fp8 on MI100/200 * fix the cmake logic for gemm_xdl_fp16_fp8 * re-enable the gemm_xdl_fp16_fp8 on MI100/200 --------- Co-authored-by: aska-0096 Co-authored-by: Sam Wu Co-authored-by: Sam Wu Co-authored-by: rocking5566 Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Co-authored-by: Rosty Geyyer Co-authored-by: carlushuang Co-authored-by: root Co-authored-by: Jun Liu Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: samjwu Co-authored-by: haocwang Co-authored-by: illsilin --- example/01_gemm/CMakeLists.txt | 15 +- example/01_gemm/gemm_wmma_fp16.cpp | 47 +- example/01_gemm/run_gemm_example.inc | 16 + .../gemm_bilinear_wmma_fp16.cpp | 87 +- .../gemm_bilinear_wmma_int8.cpp | 87 +- .../CMakeLists.txt | 2 +- .../batched_gemm_bias_e_permute_wmma_fp16.cpp | 87 +- ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 29 +- .../CMakeLists.txt | 10 +- ...e_scale_softmax_gemm_permute_wmma_fp16.cpp | 166 ++ ...m_scale_softmax_gemm_permute_wmma_fp16.cpp | 288 +++ .../cross_attention_forward_wmma_fp16.cpp | 354 ++++ ...uped_query_attention_forward_wmma_fp16.cpp | 302 +++ ...ulti_query_attention_forward_wmma_fp16.cpp | 287 +++ ...d_gemm_scale_softmax_gemm_permute_wmma.inc | 340 ++++ .../run_cross_attention_wmma.inc | 384 ++++ ...n_grouped_query_attention_forward_wmma.inc | 340 ++++ ...run_multi_query_attention_forward_wmma.inc | 339 ++++ .../run_self_attention_wmma.inc | 376 ++++ .../self_attention_forward_wmma_fp16.cpp | 332 ++++ example/64_fpAintB_gemm/CMakeLists.txt | 5 + example/64_fpAintB_gemm/common.hpp | 123 ++ .../64_fpAintB_gemm/fp16int8_gemm_wmma.cpp | 93 + example/64_fpAintB_gemm/run_gemm_example.inc | 172 ++ .../gpu/block/blockwise_gemm_wmma.hpp | 971 ++++----- ...oup_tensor_slice_transfer_v4r1_dequant.hpp | 223 +++ .../gpu/device/device_gemm_dequantB.hpp | 46 + ...d_contraction_multiple_d_wmma_cshuffle.hpp | 321 +-- ...emm_softmax_gemm_permute_wmma_cshuffle.hpp | 1729 +++++++++++++++++ .../device/impl/device_fpAintB_gemm_wmma.hpp | 714 +++++++ .../device_gemm_multiple_d_wmma_cshuffle.hpp | 359 ++-- .../gpu/device/impl/device_gemm_wmma.hpp | 417 ++-- ...conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 6 +- ..._grouped_conv_bwd_weight_wmma_cshuffle.hpp | 10 +- ...uped_conv_fwd_multiple_d_wmma_cshuffle.hpp | 270 ++- ...e_grouped_query_attention_forward_wmma.hpp | 1254 ++++++++++++ ...ice_multi_query_attention_forward_wmma.hpp | 1244 ++++++++++++ .../gpu/device/masking_specialization.hpp | 5 +- .../element/unary_element_wise_operation.hpp | 76 + ...iple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp | 2 +- ...atched_gemm_softmax_gemm_wmma_cshuffle.hpp | 1596 +++++++++++++++ .../gpu/grid/gridwise_fpAintB_gemm_wmma.hpp | 1046 ++++++++++ ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 782 ++++++-- .../grid/gridwise_gemm_pipeline_selector.hpp | 11 +- .../gpu/grid/gridwise_gemm_pipeline_v1.hpp | 410 +++- ...e_gemm_split_k_multiple_d_xdl_cshuffle.hpp | 2 +- .../gpu/grid/gridwise_gemm_wmma.hpp | 734 +++++-- .../threadwise_tensor_slice_transfer.hpp | 135 ++ ...ise_tensor_slice_transfer_v3r1_dequant.hpp | 1066 ++++++++++ .../tensor_operation/gpu/warp/wmma_gemm.hpp | 120 +- ...ransform_contraction_to_gemm_arraybase.hpp | 391 ++++ include/ck/utility/amd_buffer_addressing.hpp | 3 +- include/ck/utility/amd_inline_asm.hpp | 24 +- include/ck/utility/data_type.hpp | 15 + include/ck/utility/type_convert.hpp | 57 + .../cpu/reference_batched_gemm.hpp | 246 +++ .../cpu/reference_fpAintB_gemm.hpp | 177 ++ .../tensor_operation_instance/gpu/gemm.hpp | 24 + .../device_grouped_conv_fwd_wmma_instance.hpp | 100 +- .../gpu/gemm/CMakeLists.txt | 6 + ...emm_wmma_f16_f16_f16_km_kn_mn_instance.cpp | 78 + ...emm_wmma_f16_f16_f16_km_nk_mn_instance.cpp | 78 + ...emm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp | 158 ++ ...emm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp | 78 + ...uffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp | 48 +- ...uffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp | 80 +- .../grouped_conv2d_bwd_data/CMakeLists.txt | 32 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 34 +- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 3 +- test/grouped_convnd_bwd_data/CMakeLists.txt | 2 +- test/grouped_convnd_bwd_weight/CMakeLists.txt | 2 +- 73 files changed, 17542 insertions(+), 2020 deletions(-) create mode 100644 example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc create mode 100644 example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp create mode 100644 example/64_fpAintB_gemm/CMakeLists.txt create mode 100644 example/64_fpAintB_gemm/common.hpp create mode 100644 example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp create mode 100644 example/64_fpAintB_gemm/run_gemm_example.inc create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp create mode 100644 include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 5b71cd1548..2fa8e77462 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16) add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +if(GPU_TARGETS MATCHES "gfx11") add_custom_target(example_gemm_wmma) add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16) @@ -53,12 +53,6 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp64) add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp) -add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) - -add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) -add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) - list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) @@ -72,5 +66,12 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) + +add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) +add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) + add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) + diff --git a/example/01_gemm/gemm_wmma_fp16.cpp b/example/01_gemm/gemm_wmma_fp16.cpp index b11fe76ab2..8c52e4f7d7 100644 --- a/example/01_gemm/gemm_wmma_fp16.cpp +++ b/example/01_gemm/gemm_wmma_fp16.cpp @@ -19,15 +19,50 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle -// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| -// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmMNKPadding, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, 1>; + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 2837937ead..b04e4e53a8 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -150,6 +150,22 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k); ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(b_k_n); + break; + case 3: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n); + break; + case 4: + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{1.f, 1.f}(b_k_n); + break; + case 5: + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-2.f, 2.f}(b_k_n); + break; default: ck::utils::FillUniformDistribution{-0.1f, 0.1f}(a_m_k); ck::utils::FillUniformDistribution{-0.1f, 0.1f}(b_k_n); diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp index 877792d740..d1b820da7b 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 256, - 128, - 256, - 8, - 8, - 16, - 16, - 4, - 4, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - 1, - 1, - S<1, 32, 1, 8>, - 8>; +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; int main(int argc, char* argv[]) { @@ -264,7 +265,7 @@ int main(int argc, char* argv[]) float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; + << device_op.GetTypeString() << std::endl; e_device_buf.FromDevice(e_m_n_device_result.mData.data()); diff --git a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp index 9f23ad2652..aca136f801 100644 --- a/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp +++ b/example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp @@ -55,7 +55,7 @@ using DDataType = I8; using EDataType = I8; using ALayout = Row; -using BLayout = Row; +using BLayout = Col; using DLayout = Row; using ELayout = Row; @@ -65,48 +65,49 @@ using CDEElementOp = AlphaBetaAdd; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; -using DeviceOpInstance = - ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle, - ELayout, - ADataType, - BDataType, - ck::Tuple, - EDataType, - AccDataType, - CShuffleDataType, - AElementOp, - BElementOp, - CDEElementOp, - GemmSpec, - 32, - 16, - 16, - 4, - 16, - 16, - 16, - 1, - 1, - S<2, 16, 1>, - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 16, - 16, - 1, - S<4, 1, 8>, - S<0, 2, 1>, - S<0, 2, 1>, - 1, - 16, - 2, - 1, - 1, - 1, - S<1, 16, 1, 2>, - 8>; +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle< + ALayout, + BLayout, + ck::Tuple, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + ck::Tuple, + EDataType, + AElementOp, + BElementOp, + CDEElementOp, + GemmSpec, + 2, // Prefetch stage + 128, // BlockSize + 128, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 4, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 2, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; int main(int argc, char* argv[]) { diff --git a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt index 32a87dd200..f343cc1910 100644 --- a/example/29_batched_gemm_bias_e_permute/CMakeLists.txt +++ b/example/29_batched_gemm_bias_e_permute/CMakeLists.txt @@ -1,5 +1,5 @@ add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp) -if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102") +if(GPU_TARGETS MATCHES "gfx11") add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp) endif() diff --git a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp index 62233e5351..2bbf430c4e 100644 --- a/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp +++ b/example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp @@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough; using CDEElementOp = ck::tensor_operation::element_wise::Add; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; -static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed; +static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default; static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default; using DeviceOpInstanceKKNN = @@ -55,43 +56,44 @@ using DeviceOpInstanceKKNN = NumDimK, ADataType, BDataType, - DsDataType, - EDataType, AccDataType, CShuffleDataType, + DsDataType, + EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - ABSpec, - ABSpec, + ASpec, + BSpec, DESpec, - 256, + 1, 128, - 256, - 8, - 8, + 64, + 64, + 64, + 4, 16, 16, + 1, 4, - 4, - S<4, 64, 1>, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, + 4, + 4, true, - S<4, 64, 1>, + S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, - 8, - 8, + 4, + 4, true, 1, 1, - S<1, 32, 1, 8>, + S<1, 64, 1, 2>, 8>; using DeviceOpInstance = DeviceOpInstanceKKNN; @@ -251,6 +253,38 @@ int main(int argc, char* argv[]) ck::index_t K0 = 2048; + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + G0 = std::stoi(argv[4]); + G1 = std::stoi(argv[5]); + M0 = std::stoi(argv[6]); + M1 = std::stoi(argv[7]); + N0 = std::stoi(argv[8]); + N1 = std::stoi(argv[9]); + K0 = std::stoi(argv[10]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n"); + exit(0); + } + // A[G0, G1, M0, M1, K0] std::vector a_gs_ms_ks_lengths{G0, G1, M0, M1, K0}; std::vector a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1}; @@ -266,23 +300,6 @@ int main(int argc, char* argv[]) std::vector e_gs_ms_ns_strides{ G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1}; - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - exit(0); - } Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); Tensor b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); Tensor d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides); diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index 360b2c8947..ca8746bb97 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -42,41 +42,42 @@ using DeviceConvFwdInstance = OutputLayout, InKernelDataType, WeiKernelDataType, - ck::Tuple, - OutKernelDataType, AccDataType, CShuffleDataType, + ck::Tuple, + OutKernelDataType, InElementOp, WeiElementOp, OutElementOp, ConvSpec, // ConvForwardSpecialization GemmSpec, // GemmSpecialization - 256, // BlockSize - 128, // MPerBlock - 128, // NPerBlock - 4, // K0PerBlock + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 64, // NPerBlock + 64, // KPerBlock 8, // K1 16, // MPerWMMA 16, // NPerWMMA 4, // MRepeat - 2, // NRepeat - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + 1, // NRepeat + S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim 8, // ABlockTransferSrcScalarPerVector 8, // ABlockTransferDstScalarPerVector_AK1 true, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 8, // BBlockTransferSrcScalarPerVector 8, // BBlockTransferDstScalarPerVector_BK1 true, // BBlockLdsExtraN - 4, - 2, - S<1, 32, 1, 8>, + 1, + 1, + S<1, 16, 1, 8>, 8>; template @@ -277,9 +278,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) switch(conv_param.num_dim_spatial_) { - case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); + // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); - case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); } return false; diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt index 2a24abf094..c6cca7b586 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt +++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt @@ -1,3 +1,12 @@ +if(GPU_TARGETS MATCHES "gfx11") + add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) + add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) + add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp) + add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp) + add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp) + add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp) +endif() + add_custom_target(example_gemm_scale_softmax_gemm) add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp) @@ -20,4 +29,3 @@ add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_sc add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp) add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16) - diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000..2c7bacfc4e --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +using DeviceMHAFactory = + std::tuple, // ABlockTransfer MK -> K0 M K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // B0BlockTransfer LK -> K0 L K1 + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1 + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 8, + 1, + false, + 1, // CShuffleMWmmaPerWavePerShuffle + 2, // CShuffleNWmmaPerWavePerShuffle + S<1, 64, 1, 4>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + MaskingSpec> // MaskingSpecialization + >; +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp new file mode 100644 index 0000000000..d9ab645ee9 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..4c92c5497f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 64, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_cross_attention_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..12dcfcc36d --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Grouped Query Attention, +Ainslie, Joshua, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, and Sumit +Sanghai. “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” +arXiv, May 22, 2023. https://doi.org/10.48550/arXiv.2305.13245. + +Example is GQA-4 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; +static constexpr ck::index_t QueryGroupNumber = 4; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceGroupedQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + QueryGroupNumber, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = + ck::tensor_operation::host::ReferenceBatchedGemm_GQA; + +#include "run_grouped_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..694a320a45 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp @@ -0,0 +1,287 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Multi-Query Attention +Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.” arXiv.org, November 6, +2019. https://arxiv.org/abs/1911.02150v1. + +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +// #define CK_MHA_USE_WAVE_1 +// #define CK_MHA_USE_WAVE_2 +// #define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceMultiQueryAttentionForward_Wmma< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 128, 64, 8, 8, + // Gemm 1 + 64, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 4, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_MQA; + +#include "run_multi_query_attention_forward_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc new file mode 100644 index 0000000000..2e77479bcc --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::vector{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, G1, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::vector{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + ck::index_t BatchCount = G0 * G1; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, M, K}); + Tensor b0_g_k_n({BatchCount, K, N}); + Tensor b1_g_n_o({BatchCount, N, O}); + Tensor acc0_g_m_n({BatchCount, M, N}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, M, N}); // scratch object after softmax + Tensor c_g_m_o_host_result({BatchCount, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * G1 + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * G1 + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * G1 + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc new file mode 100644 index 0000000000..9ff4c56e06 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_cross_attention_wmma.inc @@ -0,0 +1,384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t q_sequence_length = 256; + ck::index_t kv_sequence_length = 64; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, q_sequence_length, head_num, head_dim]. Batch dim, outer dim, + // inner dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + q_sequence_length = std::stoi(argv[4]); + kv_sequence_length = std::stoi(argv[5]); + head_dim = std::stoi(argv[6]); + batch_size = std::stoi(argv[7]); + head_num = std::stoi(argv[8]); + + alpha = std::stof(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf( + "arg4 to 8: q_sequence_length, kv_sequence_length, head_dim, batch_size, head_num\n"); + printf("arg9: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, q_sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{ + batch_size, head_num, head_dim, kv_sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{kv_sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, kv_sequence_length, head_num, head_dim] + : std::vector{ + head_num * kv_sequence_length * head_dim, + kv_sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, kv_sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, q_sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? std::vector{q_sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, q_sequence_length, head_num, head_dim] + : std::vector{ + head_num * q_sequence_length * head_dim, + q_sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, q_sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector kv_gs_ns_ks_lengths{ + batch_size, head_num, kv_sequence_length, 2, head_dim}; + std::vector kv_gs_ns_ks_strides = std::vector{ + kv_sequence_length * head_num * 2 * head_dim, + 2 * head_dim, + head_num * 2 * head_dim, + head_dim, + 1}; // kv layout [batch_size, q_sequence_length, head_num, 2, head_dim] + Tensor kv_gs_ns_ks(kv_gs_ns_ks_lengths, kv_gs_ns_ks_strides); + // merge kv into a packed pointer send to device + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { kv_gs_ns_ks(idx[0], idx[1], idx[3], 1, idx[2]) = self(idx); }); + DeviceMem q_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem kv_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + q_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + kv_device_buf.ToDevice(kv_gs_ns_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeCrossAttnInvoker(); + auto argument = + gemm.MakeCrossAttnArgument(static_cast(q_device_buf.GetDeviceBuffer()), + static_cast(kv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + q_sequence_length, + kv_sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(q_sequence_length) * kv_sequence_length * head_dim * 2 + + size_t(q_sequence_length) * kv_sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * q_sequence_length * head_dim + + sizeof(B0DataType) * head_dim * kv_sequence_length + + sizeof(B1DataType) * kv_sequence_length * head_dim + + sizeof(CDataType) * q_sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, q_sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, kv_sequence_length}); + Tensor b1_g_n_o({BatchCount, kv_sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, q_sequence_length, kv_sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n({BatchCount, + q_sequence_length, + kv_sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, q_sequence_length, head_dim}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(kv_sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", q_sequence_length: " << q_sequence_length + << ", kv_sequence_length: " << kv_sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc new file mode 100644 index 0000000000..609d085299 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_query_attention_forward_wmma.inc @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 64; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 4; + ck::index_t G1 = 16; + ck::index_t KV_head = QueryGroupNumber; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = + (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0 * QueryGroupNumber; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_gq_k_n({G0, QueryGroupNumber, K, N}); + Tensor b1_g0_gq_n_o({G0, QueryGroupNumber, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_gq_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_gq_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_gq_k_n, + acc0_g0_g1_m_n, + a_element_op, + b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_gq_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc new file mode 100644 index 0000000000..b05915c07f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_multi_query_attention_forward_wmma.inc @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t M = 120; + ck::index_t N = 1000; + ck::index_t K = 64; + ck::index_t O = 128; + + // Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape + // C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) + // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t G0 = 7; + ck::index_t G1 = 13; + ck::index_t KV_head = 1; + + float alpha = 1; + + bool input_permute = false; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 13) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + G0 = std::stoi(argv[8]); + G1 = std::stoi(argv[9]); + + alpha = std::stof(argv[10]); + + input_permute = std::stoi(argv[11]); + output_permute = std::stoi(argv[12]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 11: M, N, K, O, G0, G1\n"); + printf("arg10: scale (alpha)\n"); + printf("arg11 to 12: input / output permute\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{G0, G1, M, K}; + std::vector a_gs_ms_ks_strides = + input_permute + ? std::vector{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::vector{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::vector b0_gs_ns_ks_lengths{G0, KV_head, N, K}; + std::vector b0_gs_ns_ks_strides = + input_permute + ? std::vector{N * KV_head * K, K, KV_head * K, 1} + // B0 layout [G0, N, G1, K] + : std::vector{KV_head * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::vector b1_gs_os_ns_lengths{G0, KV_head, O, N}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{N * KV_head * O, O, 1, KV_head * O} + // B1 layout [G0, N, G1, O] + : std::vector{KV_head * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::vector c_gs_ms_os_lengths{G0, G1, M, O}; + std::vector c_gs_ms_os_strides = + output_permute + ? std::vector{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::vector{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_gs_ms_ks.mData.data()); + b0_device_buf.ToDevice(b0_gs_ns_ks.mData.data()); + b1_device_buf.ToDevice(b1_gs_os_ns.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + G0, + G1, + alpha, + input_permute, + output_permute); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(CDataType) * M * O) * G0 * G1 + + (sizeof(B0DataType) * K * N + sizeof(B1DataType) * N * O) * G0; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g0_g1_m_k({G0, G1, M, K}); + Tensor b0_g0_1_k_n({G0, 1, K, N}); + Tensor b1_g0_1_n_o({G0, 1, N, O}); + Tensor acc0_g0_g1_m_n({G0, G1, M, N}); // scratch object after gemm0 + Tensor a1_g0_g1_m_n({G0, G1, M, N}); // scratch object after softmax + Tensor c_g0_g1_m_o_host_result({G0, G1, M, O}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g0_g1_m_k(idx[0], idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g0_1_k_n(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g0_1_n_o(idx[0], idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument(a_g0_g1_m_k, + b0_g0_1_k_n, + acc0_g0_g1_m_n, + a_element_op, + b0_element_op, + acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(N); + acc0_g0_g1_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[2], idx[3])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = + ref_softmax.MakeArgument(acc0_g0_g1_m_n, a1_g0_g1_m_n, 1, 0, {3}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g0_g1_m_n, + b1_g0_1_n_o, + c_g0_g1_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach( + [&](auto& self, auto idx) { self(idx) = c_g0_g1_m_o_host_result(idx); }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MQA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << G0 << ", HeadNum: " << G1 << ", M: " << M + << ", N: " << N << ", K: " << K << ", O: " << O << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc new file mode 100644 index 0000000000..3fdaaebb0f --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/run_self_attention_wmma.inc @@ -0,0 +1,376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +int run(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape for A/B0/B1/C + // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o + ck::index_t sequence_length = 256; + ck::index_t head_dim = 80; + + // Output shape C[batch_size, sequence_length, head_num, head_dim]. Batch dim, outer dim, inner + // dim must match GEMM shape C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o]) C_g0_m_g1_o = + // permute(C_g0_g1_m_o, [0, 2, 1, 3]) + ck::index_t batch_size = 2; + ck::index_t head_num = 8; + + float alpha = 1; + bool input_permute = true; + bool output_permute = true; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + sequence_length = std::stoi(argv[4]); + head_dim = std::stoi(argv[5]); + batch_size = std::stoi(argv[6]); + head_num = std::stoi(argv[7]); + + alpha = std::stof(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 7: sequence_length, head_dim, batch_size, head_num\n"); + printf("arg8: scale (alpha)\n"); + exit(0); + } + + std::vector a_gs_ms_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector a_gs_ms_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // A layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // A layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b0_gs_ns_ks_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector b0_gs_ns_ks_strides = + input_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // B0 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // B0 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector b1_gs_os_ns_lengths{batch_size, head_num, head_dim, sequence_length}; + std::vector b1_gs_os_ns_strides = + input_permute + ? std::vector{sequence_length * head_num * head_dim, + head_dim, + 1, + head_num * head_dim} + // B1 layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + 1, + head_dim}; // B1 layout [batch_size, head_num, sequence_length, head_dim] + + std::vector c_gs_ms_os_lengths{batch_size, head_num, sequence_length, head_dim}; + std::vector c_gs_ms_os_strides = + output_permute ? std::vector{sequence_length * head_num * head_dim, + head_dim, + head_num * head_dim, + 1} + // C layout [batch_size, sequence_length, head_num, head_dim] + : std::vector{ + head_num * sequence_length * head_dim, + sequence_length * head_dim, + head_dim, + 1}; // C layout [batch_size, head_num, sequence_length, head_dim] + + Tensor a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + Tensor b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + Tensor b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + Tensor c_gs_ms_os_host_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + Tensor c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + std::cout << "a_gs_ms_ks: " << a_gs_ms_ks.mDesc << std::endl; + std::cout << "b0_gs_ns_ks: " << b0_gs_ns_ks.mDesc << std::endl; + std::cout << "b1_gs_os_ns: " << b1_gs_os_ns.mDesc << std::endl; + std::cout << "c_gs_ms_os: " << c_gs_ms_os_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 2: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + case 4: // A, B0, B1 1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: // Rand: b1 b0; unit: a + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 6: // Rand: a b0 ; unit: B1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 7: // Rand: a b1 ; unit: b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + case 8: // Rand: a ; unit: b0 b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 9: // Rand: b0 ; unit: a b1 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 10: // Rand: b1 ; unit: a b0 + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + break; + default: + a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{}); + b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + std::vector qkv_gs_ms_ks_lengths{ + batch_size, head_num, sequence_length, 3, head_dim}; + std::vector qkv_gs_ms_ks_strides = std::vector{ + sequence_length * head_num * 3 * head_dim, + 3 * head_dim, + head_num * 3 * head_dim, + head_dim, + 1}; // qkv layout [batch_size, sequence_length, head_num, 3, head_dim] + Tensor qkv_gs_ms_ks(qkv_gs_ms_ks_lengths, qkv_gs_ms_ks_strides); + // merge qkv into a packed pointer send to device + a_gs_ms_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 0, idx[3]) = self(idx); }); + b0_gs_ns_ks.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[2], 1, idx[3]) = self(idx); }); + b1_gs_os_ns.ForEach( + [&](auto& self, auto idx) { qkv_gs_ms_ks(idx[0], idx[1], idx[3], 2, idx[2]) = self(idx); }); + DeviceMem qkv_device_buf(sizeof(ADataType) * a_gs_ms_ks.mDesc.GetElementSpaceSize() + + sizeof(B0DataType) * b0_gs_ns_ks.mDesc.GetElementSpaceSize() + + sizeof(B1DataType) * b1_gs_os_ns.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * + c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()); + qkv_device_buf.ToDevice(qkv_gs_ms_ks.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{alpha}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + float best_perf = .0; + float best_time = .0; + int not_pass = 0; + std::string best_kernel = ""; + printf("Verification: %s\n", do_verification ? "ON" : "OFF"); + // TODO ANT: replace array with vector? + ck::static_for<0, std::tuple_size_v, 1>{}([&](auto i) -> void { + const auto device_mha_instance = std::get(DeviceMHAFactory{}); + + using DeviceMHAInstance = ck::remove_cvref_t; + auto gemm = DeviceMHAInstance{}; + auto invoker = gemm.MakeSelfAttnInvoker(); + auto argument = + gemm.MakeSelfAttnArgument(static_cast(qkv_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + batch_size, + sequence_length, + head_num, + head_dim, + alpha); + + // if(!gemm.IsSupportedArgument(argument)) + // { + // std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + // return 0; + // } + + ck::index_t BatchCount = batch_size * head_num; + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(sequence_length) * sequence_length * head_dim * 2 + + size_t(sequence_length) * sequence_length * head_dim * 2) * + BatchCount; + std::size_t num_btype = (sizeof(ADataType) * sequence_length * head_dim + + sizeof(B0DataType) * head_dim * sequence_length + + sizeof(B1DataType) * sequence_length * head_dim + + sizeof(CDataType) * sequence_length * head_dim) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + if(tflops > best_perf) + { + best_perf = tflops; + best_time = ave_time * 1000; + best_kernel = gemm.GetTypeString(); + } + if(do_verification) + { + c_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); + + Tensor a_g_m_k({BatchCount, sequence_length, head_dim}); + Tensor b0_g_k_n({BatchCount, head_dim, sequence_length}); + Tensor b1_g_n_o({BatchCount, sequence_length, head_dim}); + Tensor acc0_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after gemm0 + Tensor a1_g_m_n( + {BatchCount, sequence_length, sequence_length}); // scratch object after softmax + Tensor c_g_m_o_host_result( + {BatchCount, sequence_length, head_dim}); // scratch object after gemm1 + + // permute + a_gs_ms_ks.ForEach([&](auto& self, auto idx) { + a_g_m_k(idx[0] * head_num + idx[1], idx[2], idx[3]) = self(idx); + }); + b0_gs_ns_ks.ForEach([&](auto& self, auto idx) { + b0_g_k_n(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + b1_gs_os_ns.ForEach([&](auto& self, auto idx) { + b1_g_n_o(idx[0] * head_num + idx[1], idx[3], idx[2]) = self(idx); + }); + + // gemm 0 + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + // masking + const auto mask = typename DeviceMHAInstance::C0MatrixMask(sequence_length); + acc0_g_m_n.ForEach([&](auto& self, auto idx) { + if(mask.IsMaskedElement(idx[1], idx[2])) + self(idx) = -ck::NumericLimits::Infinity(); + }); + + // softmax + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + // gemm1 + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n, + b1_g_n_o, + c_g_m_o_host_result, + PassThrough{}, + b1_element_op, + c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + // permute + c_gs_ms_os_host_result.ForEach([&](auto& self, auto idx) { + const size_t& g0 = idx[0]; + const size_t& g1 = idx[1]; + + const size_t g = g0 * head_num + g1; + + self(idx) = c_g_m_o_host_result(g, idx[2], idx[3]); + }); + + // default absolute error and relative error is 0.001 + double rtol = 1e-3; + double atol = 1e-3; + + // when BF16 is taken, set absolute error and relative error to 0.01 + if(std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + { + rtol = 1e-2; + atol = 1e-2; + } + + bool this_run_verification = ck::utils::check_err(c_gs_ms_os_device_result.mData, + c_gs_ms_os_host_result.mData, + "Error: Incorrect results!", + rtol, + atol); + printf("Verification: %s, Pass: %s\n", + do_verification ? "ON" : "OFF", + this_run_verification ? "YES" : "NO"); + + if(!this_run_verification) + { + not_pass = 1; + printf("%d th MHA instance verification Failed \n", i.value); + } + } + }); + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Problem Size: BatchCount: " << batch_size << ", HeadNum: " << head_num + << ", sequence_length: " << sequence_length << ", head_dim: " << head_dim + << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + std::cout << "Best kernel: " << best_kernel << " , " << best_perf << " TFlops , " << best_time + << " us" << std::endl; + std::cout << "---------------------------------------------------------------------------------" + "-----------" + << std::endl; + return not_pass; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp new file mode 100644 index 0000000000..8e037272b8 --- /dev/null +++ b/example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g_k_l) * B1_g_l_n + |-----------------| + Gemm0 + |-------------------------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using Acc0DataType = F32; +using Acc1DataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; +using Acc0BiasDataType = ck::Tuple<>; +using Acc1BiasDataType = ck::Tuple<>; + +static constexpr ck::index_t NumDimG = 2; +static constexpr ck::index_t NumDimM = 1; +static constexpr ck::index_t NumDimN = 1; +static constexpr ck::index_t NumDimK = 1; +static constexpr ck::index_t NumDimO = 1; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKOPadding; +static constexpr auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + +static constexpr auto TensorSpecA = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB0 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecB1 = ck::tensor_operation::device::TensorSpecialization::Default; +static constexpr auto TensorSpecC = ck::tensor_operation::device::TensorSpecialization::Default; + +// clang-format off +#define CK_MHA_USE_WAVE_1 +#define CK_MHA_USE_WAVE_2 +#define CK_MHA_USE_WAVE_4 +#define CK_MHA_USE_WAVE_8 +using DeviceMHAFactory = + std::tuple< +#ifdef CK_MHA_USE_WAVE_1 + // 1 wave, mrepeat = 1, nrepeat = 2, k/o repeat = 1~5 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 32, + // Gemm 0 + 16, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 2, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 16, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_2 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 64, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 64, + // Gemm 0 + 32, 32, 160, 8, 8, + // Gemm 1 + 80, 32, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 2, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 32, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_4 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 128, 80, 8, 8, + // Gemm 1 + 80, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 8, 5, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 192, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 128, + // Gemm 0 + 64, 64, 48, 8, 8, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 4, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<2, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 64, 1, 2>, 8, + MaskingSpec>, +#endif +#ifdef CK_MHA_USE_WAVE_8 + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle< + NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, + ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc0DataType, Acc1BiasDataType, Acc1DataType, CShuffleDataType, + AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, + GemmSpec, TensorSpecA, TensorSpecB0, TensorSpecB1, TensorSpecC, 1, + 256, + // Gemm 0 + 128, 192, 48, 8,4, + // Gemm 1 + 48, 64, 8, + 16, 16, 16, + // Per repeat = wave_m = wave_num, wave_n = 1 + 1, 12, 3, + // ABlockTransfer MK -> K0 M K1 + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, + // B0BlockTransfer LK -> K0 L K1 + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, + // B1BlockTransfer NL -> L0 N L1 + S<2, 16, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, false, + // CShuffleBlockTransfer MN + 1, 1, S<1, 128, 1, 2>, 8, + MaskingSpec> +#endif + >; +// clang-format on +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +#include "run_self_attention_wmma.inc" + +int main(int argc, char* argv[]) { return run(argc, argv); } diff --git a/example/64_fpAintB_gemm/CMakeLists.txt b/example/64_fpAintB_gemm/CMakeLists.txt new file mode 100644 index 0000000000..89cc2d7f62 --- /dev/null +++ b/example/64_fpAintB_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +if(GPU_TARGETS MATCHES "gfx11") + add_custom_target(example_fpAintB_gemm_wmma) + add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp) + add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma) +endif() diff --git a/example/64_fpAintB_gemm/common.hpp b/example/64_fpAintB_gemm/common.hpp new file mode 100644 index 0000000000..4fb4c41d05 --- /dev/null +++ b/example/64_fpAintB_gemm/common.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/fill.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp" + +struct ProblemSize final +{ + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; +}; + +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +template +struct UnsignedWeightPreprocessor +{ +}; + +template <> +struct UnsignedWeightPreprocessor +{ + using UnsignedWeight = Tensor; + using SignedWeight = Tensor; + static UnsignedWeight convert(SignedWeight const& Input) + { + + UnsignedWeight Output = Input.template CopyAsType(); + + auto f_kn = [&](auto k, auto n) { + const uint8_t adder = 128; + int8_t v_signed_weight; + uint8_t v_unsigned_weight; + + ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n)); + v_unsigned_weight = ck::type_convert(v_signed_weight) + adder; + Output(k, n) = v_unsigned_weight; + }; + + make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return Output; + } + + UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); } +}; + +inline bool +parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +{ + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 10) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + + problem_size.M = std::stoi(argv[4]); + problem_size.N = std::stoi(argv[5]); + problem_size.K = std::stoi(argv[6]); + + problem_size.StrideA = std::stoi(argv[7]); + problem_size.StrideB = std::stoi(argv[8]); + problem_size.StrideC = std::stoi(argv[9]); + } + else + { + std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" + << std::endl + << "arg3: time kernel (0=no, 1=yes)" << std::endl + << "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl; + return false; + } + + return true; +} diff --git a/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp new file mode 100644 index 0000000000..9dc97fecd8 --- /dev/null +++ b/example/64_fpAintB_gemm/fp16int8_gemm_wmma.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp" + +// Implementation follows the paper: +// Kim, Young Jin, Rawn Henry, Raffy Fahim, and Hany Hassan Awadalla. “Who Says Elephants Can’t Run: +// Bringing Large Scale MoE Models into Cloud Scale Production.” arXiv, November 17, 2022. +// https://doi.org/10.48550/arXiv.2211.10017. Assume weight (Matrix B) is add preprocess to +// unsigned. + +// The DeviceOp is CDataType = ADataType * Dequant(BDataType) * ScaleDataType +// The HostRef is CDataType = ADataType * Dequant(QuantDataType) * ScaleDataType + +// TODO: Current implementation consume more VGPR than expected. + +using ADataType = ck::half_t; +using QuantDataType = int8_t; +using BDataType = uint8_t; +using ScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = float; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceFpAintBGemm_Wmma_CShuffle + < ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + ScaleDataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmDefault, + 1, // Prefetch stage + 128, // BlockSize + 64, // MPerBlock + 128, // NPerBlock + 64, // KPerBlock + 8, // K1 + 16, // MPerWmma + 16, // NPerWmma + 2, // M-Repeat // M-PerWmma / M-Repeat = M-Wave + 4, // N-Repeat // N-PerWmma / N-Repeat = N-Wave + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 32, 1>, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + 1, // C shuffle (M Repeat) Per store + 1, // C shuffle (N Repeat) Per store + S<1, 32, 1, 4>, + 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferencefpAintBGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/example/64_fpAintB_gemm/run_gemm_example.inc b/example/64_fpAintB_gemm/run_gemm_example.inc new file mode 100644 index 0000000000..dc2bdc18f0 --- /dev/null +++ b/example/64_fpAintB_gemm/run_gemm_example.inc @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ +#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4) + static_assert(sizeof(ck::int4_t) == sizeof(int8_t)); +#endif + + using namespace ck::literals; + + auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor quant_b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + // assume scale tensor is [1, n] + Tensor scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{})); + + switch(config.init_method) + { + case 0: break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-1.f, 1.f}(scale_k_n); + break; + case 2: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + break; + default: + ck::utils::FillUniformDistribution{-1.f, 1.f}(a_m_k); + ck::utils::FillUniformDistribution{-1.f, 1.f}(quant_b_k_n); + ck::utils::FillUniformDistribution{-1.f, 1.f}(scale_k_n); + } + + UnsignedWeightPreprocessor preprocessor; + Tensor b_k_n = preprocessor(quant_b_k_n); + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "scale_k_n: " << scale_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + +#ifdef BUILD_INT4_EXAMPLE + DeviceMem a_m_k_device_buf(sizeof(KernelADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(KernelBDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(KernelCDataType) * + c_m_n_device_result.mDesc.GetElementSpaceSize()); + + const Tensor a_m_k_converted(a_m_k); + const Tensor b_k_n_converted(b_k_n); + + a_m_k_device_buf.ToDevice(a_m_k_converted.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_converted.mData.data()); +#else + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem scale_k_n_device_buf(sizeof(ScaleDataType) * scale_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + scale_k_n_device_buf.ToDevice(scale_k_n.mData.data()); +#endif + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument( +#ifdef BUILD_INT4_EXAMPLE + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#else + static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(scale_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), +#endif + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + if(config.do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + quant_b_k_n, + scale_k_n, + c_m_n_host_result, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + +#ifdef BUILD_INT4_EXAMPLE + Tensor c_m_n_device_result_converted(c_m_n_host_result.mDesc); + + c_m_n_device_buf.FromDevice(c_m_n_device_result_converted.mData.data()); + + c_m_n_device_result = c_m_n_device_result_converted.CopyAsType(); + + return ck::utils::check_err(c_m_n_device_result_converted, c_m_n_host_result); +#else + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + return ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#endif + } + + return true; +} + +bool run_gemm_example(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp index b3d45f3d0c..f8ee283c67 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -7,6 +7,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" #include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #define CK_MNK_LOOP @@ -16,25 +17,45 @@ template -/* A: K0PerBlock x MPerBlock x K1 + index_t KPack, + bool AEnableLds = true, + bool BEnableLds = true, + bool TransposeC = false> +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle +struct BlockwiseGemmWMMA { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; static constexpr auto WmmaK = Number<16>{}; using ThisThreadBlock = ThisThreadBlock; @@ -42,18 +63,16 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. static constexpr index_t WaveSize = 32; - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = AEnableLds ? 1 : 2; + static constexpr index_t B_KRow = BEnableLds ? 1 : 2; + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); static constexpr auto wmma_gemm = - WmmaGemm{}; + WmmaGemm{}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); @@ -79,371 +98,39 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); } + // Default, Block buffer in LDS, thread level offset enabled __device__ static auto CalculateAThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); - } - - template - __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - const auto waveId_n = wave_idx[I1]; - - const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); - - constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), - make_tuple(Sequence<0>{}), - make_tuple(Sequence<0, 1, 2>{})); - - const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( - make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; - const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( - make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; - - return make_tuple(c_thread_m, c_thread_n); - } - - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle() - { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, - "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); - - static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && - NPerBlock % (NPerWMMA * NRepeat) == 0, - "wrong!"); - } - - // Thread level, register decriptor. Vector-write - __host__ __device__ static constexpr auto - GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = - wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( - // |MRepeat |MWave |MSubGroup |NRepeat |NWave - // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); - } - - // Provide dimension size - __host__ __device__ static constexpr auto - GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() - { - constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - Number{})); - - return wmma_gemm - .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( - c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); - } - - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - // basic intrinsic to determine loopover direction - if constexpr(MRepeat < NRepeat) + if constexpr(AEnableLds) { - static_for<0, KPerBlock / WmmaK, 1>{}( - [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0); } else { - static_for<0, KPerBlock / WmmaK, 1>{}( - [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... - static_for<0, NRepeat, 1>{}([&](auto n0) { - // read B - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, n0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, n0, I0, I0, I0), - b_thread_buf); - static_for<0, MRepeat, 1>{}([&](auto m0) { - // read A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, m0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, m0, I0, I0, I0), - a_thread_buf); - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = - b_thread_buf[Number{}]; - }); - - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - }); - }); - }); + return make_tuple(0, 0, 0, 0, 0, 0); } } - protected: - // A[K0, M0, M1, M2, K1] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // B[K0, N0, N1, N2, K1] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); - - // C[M, N, NumRegWMMA] - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; - - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; -}; - -// block wise level pipe designed for inline asm -template -/* A: K0PerBlock x MPerBlock x K1 - * B: K0PerBlock x NPerBlock x K1 - * C: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs - * KPACK == WMMA_K = 16 - */ -struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto WmmaK = Number<16>{}; - - using ThisThreadBlock = ThisThreadBlock; - - // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. - static constexpr index_t WaveSize = 32; - - static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t KPerBlock = - BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); - static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); - - static constexpr auto wmma_gemm = - WmmaGemm{}; - - static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); - static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); - - StaticBufferTupleOfVector - c_thread_buf_; - - __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } - - __device__ static auto GetWaveIdx() - { - const index_t thread_id = ThisThreadBlock::GetThreadId(); - - constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); - } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const auto wave_idx = GetWaveIdx(); - - const auto waveId_m = wave_idx[I0]; - - const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); - // |KRepeat |MRepeat|MWave |MLane |KPack - return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); - } - __device__ static auto CalculateBThreadOriginDataIndex() { - const auto wave_idx = GetWaveIdx(); + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - const auto waveId_n = wave_idx[I1]; - - const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); - // |KRepeat |NRepeat|Nwave |NLane |KPack - return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } } template @@ -474,10 +161,26 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO return make_tuple(c_thread_m, c_thread_n); } - __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO() + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, @@ -487,6 +190,22 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO NPerBlock % (NPerWMMA * NRepeat) == 0, "wrong!"); } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + // Thread level, register decriptor. Vector-write __host__ __device__ static constexpr auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -494,20 +213,19 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); - constexpr auto MSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I0]; - constexpr auto NThreadPerSubGroup = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I1]; - constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; - - return make_naive_tensor_descriptor_packed( + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( // |MRepeat |MWave |MSubGroup |NRepeat |NWave // |NThreadPerSubGroup |MAccVgprs - make_tuple(Number{}, - I1, - MSubGroup, - Number{}, - I1, - NThreadPerSubGroup, - MAccVgprs)); + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); } template @@ -532,6 +250,23 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); } + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + // Provide dimension size __host__ __device__ static constexpr auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() @@ -549,33 +284,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); } - __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1() - { - return transform_tensor_descriptor( - AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - - __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1() - { - return transform_tensor_descriptor( - BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); - } - + // Describe how data allocated in thread copy src buffer // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma - static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1(); - static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1(); + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; template __device__ void Run(const ABlockBuffer& a_block_buf, @@ -587,268 +299,235 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr auto RepeatDiff = MRepeat - NRepeat; - // Read all Mrepeat, Nrepeat - static_for<0, NRepeat, 1>{}([&](auto iN) { - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); - }); + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - static_for<0, MRepeat, 1>{}([&](auto iM) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(I0, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - // Read Consumed Next inner loop A - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; - static_for{}([&](auto iWmmaK) { - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = c_thread_desc_.CalculateOffset( - make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); - // Read Consumed Next inner loop A - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple( - Number{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; - // Col Repeatation - static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); }); - // Read Consumed Next inner loop B - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, Number{}, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - b_thread_buf); }); - - // Stage 1: Cut to Repeat Retangle to Square, assume MRepeat > NRepeat - static_for<0, RepeatDiff, 1>{}([&](auto iCut) { - static_for<0, NRepeat, 1>{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - if constexpr(KPerBlock > WmmaK) - { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(iWmmaK + WmmaK) / A_K1>{}, Number{}, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, Number{}, I0, I0, I0), - a_thread_buf); - } - }); - }); - - // Stage 2: Run FIFO fashion loopover in Square - static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) { - // Row Repeatation - static_for{}([&](auto iN) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(WmmaInnerloop + RepeatDiff, iN, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - - // Col Repeatation - static_for{}([&](auto iM) { - vector_type a_thread_vec; - vector_type b_thread_vec; - - static_for<0, WmmaK, 1>{}([&](auto iK) { - a_thread_vec.template AsType()(iK) = - a_thread_buf[Number{}]; - b_thread_vec.template AsType()(iK) = - b_thread_buf[Number{}]; - }); - using wmma_input_type_a = typename vector_type::type; - using wmma_input_type_b = typename vector_type::type; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0)); - // s_nop(); - wmma_gemm.template Run( - a_thread_vec.template AsType()(Number<0>{}), - b_thread_vec.template AsType()(Number<0>{}), - c_thread_buf.GetVectorTypeReference(Number{})); - // s_nop(); - }); - }); + } } protected: - // A[M0, M1, M2, K0 = WmmaK] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); - // B[N0, N1, N2, K0 = WmmaK] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); // C[M, N, NumRegWMMA] static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - A_K1, - A_K1>; + template + struct AThreadCopySelector; - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3, 4>, - 4, - B_K1, - B_K1>; + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; - AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; - BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? false : true>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? true : false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp new file mode 100644 index 0000000000..ab826bb041 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer with dequantization + * + * RunRead would load low-precision data and scale data. + * RunWrite would process dequantization process. + * Assume Scale is identical along K-dimension + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1_dequant +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto scale_thread_slice_lengths = + BlockScaleSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_block_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + scale_desc, + make_zero_multi_index(), + scale_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{} && + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetScaleSliceOrigin( + scale_desc, scale_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + // With the assumption, scale scratch is always one + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunScaleRead(scale_desc, scale_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + // We don't prefer use this API directly + /* + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + */ + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + // With the assumption, scale buffer don't need move slice window method + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1_dequant; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp new file mode 100644 index 0000000000..acb18efabf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline +// As input tensor thread buffer declared inside blockwise-gemm pipeline. + +template +struct DeviceGemm_dequantB : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp index b32f3a8daa..d35645c068 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -62,10 +62,10 @@ template struct DeviceBatchedContractionMultipleD_Wmma_CShuffle @@ -123,15 +123,32 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] - static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, - const std::vector& a_gs_ms_ks_strides_vec) + static auto MakeAGridDescriptor(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) { assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); @@ -158,36 +175,72 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for K0, K1, ... const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); - if constexpr(ASpec == TensorSpecialization::Packed) + const auto a_grid_desc_m_k = [&]() { + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) { - auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( - make_tuple(M, K), - make_tuple(a_ms_ks_strides[Number{}], - a_ms_ks_strides[Number{}])); - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] - const auto a_grid_desc_ms_ks = - make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; - // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] - const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( - a_grid_desc_ms_ks, - make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), - make_tuple(mDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] - static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, - const std::vector& b_gs_ns_ks_strides_vec) + static auto MakeBGridDescriptor(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) { assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); @@ -214,30 +267,66 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // lengths for N0, N1, ... const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); - if constexpr(BSpec == TensorSpecialization::Packed) + const auto b_grid_desc_n_k = [&]() { + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) { - auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); - auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); - const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( - make_tuple(N, K), - make_tuple(b_ns_ks_strides[Number{}], - b_ns_ks_strides[Number{}])); - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); } else { - // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] - const auto b_grid_desc_ns_ks = - make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; - // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] - const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( - b_grid_desc_ns_ks, - make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), - make_tuple(nDimIds, kDimIds), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); } } @@ -393,8 +482,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle } // Gridwise descriptor, mapping to whole given provblem. - using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); - using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); @@ -449,45 +536,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EGridDesc_G_M_N e_grid_desc_g_m_n_; }; - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_K0_M_K1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK0 = K / K1; - - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_K0_N_K1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK0 = K / K1; - - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_K0_M_K1 = decltype(DeviceOp::MakeAGridDescriptor_K0_M_K1(AGridDesc_M_K{})); - using BGridDesc_K0_N_K1 = decltype(DeviceOp::MakeBGridDescriptor_K0_N_K1(BGridDesc_N_K{})); + using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -496,8 +549,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -508,9 +561,9 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -523,6 +576,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -531,6 +585,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -564,16 +619,14 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_b_grid_{static_cast(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_m_k_{}, - b_grid_desc_n_k_{}, + a_grid_desc_{}, + b_grid_desc_{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_g_m_n_{ DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, e_grid_desc_g_m_n_{ DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{}, e_grid_desc_mblock_mperblock_nblock_nperblock{}, block_2_ctile_map_{}, @@ -600,10 +653,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle p_ds_grid_(i) = static_cast(p_ds_grid[i]); }); - a_grid_desc_m_k_ = - DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); - b_grid_desc_n_k_ = - DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); ds_grid_desc_m_n_ = DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); @@ -611,9 +662,6 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); - a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(a_grid_desc_m_k_); - b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(b_grid_desc_n_k_); - block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); ds_grid_desc_mblock_mperblock_nblock_nperblock = @@ -644,16 +692,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle EDataType* p_e_grid_; // Tensor Descriptors - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; DsGridDesc_G_M_N ds_grid_desc_g_m_n_; EGridDesc_G_M_N e_grid_desc_g_m_n_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -686,6 +731,11 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle // Batch Offset ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // for checking vector load/store + // index_t MRaw_; + // index_t NRaw_; + // index_t KRaw_; }; // Invoker @@ -700,8 +750,17 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; - const auto K = - arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -712,8 +771,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle BDataType, typename GridwiseOp::DsGridPointer, EDataType, - DeviceOp::AGridDesc_K0_M_K1, - DeviceOp::BGridDesc_K0_N_K1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, AElementwiseOperation, @@ -733,8 +792,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.p_ds_grid_, arg.p_e_grid_, G, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, arg.e_grid_desc_mblock_mperblock_nblock_nperblock, arg.a_element_op_, @@ -774,6 +833,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle { if constexpr(!(is_same_v || is_same_v)) { + printf("DeviceOp: Arch check failure\n"); return false; } } @@ -782,12 +842,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle return false; } - if(!GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { + printf("GridwiseOp: Validity check failure\n"); return false; } @@ -800,16 +861,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(ABlockTransferSrcVectorDim == 1) { if(!(arg.a_mz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + arg.a_grid_desc_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access A-m check failure\n"); return false; } } else { if(!(arg.a_kz_stride_ == 1 && - arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + arg.a_grid_desc_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access A-k check failure\n"); return false; } } @@ -818,16 +881,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle if constexpr(BBlockTransferSrcVectorDim == 1) { if(!(arg.b_nz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-n check failure\n"); return false; } } else { if(!(arg.b_kz_stride_ == 1 && - arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) { + printf("DeviceOp: Vector Access B-k check failure\n"); return false; } } @@ -841,6 +906,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) { + printf("DeviceOp: Vector Access D-n check failure\n"); valid_d_access = false; } }); @@ -857,6 +923,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle 0) || CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) { + printf("DeviceOp: Vector Access E-n check failure\n"); return false; } @@ -967,14 +1034,18 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock << ", " + << KPerBlock << ", " << K1 << ", " - << MPerWMMA << ", " - << NPerWMMA << ", " + << MPerWmma << ", " + << NPerWmma << ", " << MRepeat << ", " << NRepeat << ">" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp new file mode 100644 index 0000000000..e218ee5c15 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp @@ -0,0 +1,1729 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::array{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::array{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::array{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::array{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size}; + std::array qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length}; + std::array v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size}; + std::array c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("sequence_length: %d\n", sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("qkv_gap: %d\n", qkv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_qkv_grid + 0 * qkv_gap + a_batch_offset, + p_qkv_grid + 1 * qkv_gap + b0_batch_offset, + p_qkv_grid + 2 * qkv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_qkv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Cross-Attention +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, + const KVDataType* __restrict__ p_kv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + std::array k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size}; + std::array k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length}; + std::array v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("q_sequence_length: %d\n", q_sequence_length); + printf("k_sequence_length: %d\n", kv_sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("kv_gap: %d\n", kv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_q_grid + a_batch_offset, + p_kv_grid + 0 * kv_gap + b0_batch_offset, + p_kv_grid + 1 * kv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_q_grid; + ignore = p_kv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = q_sequence_length; + ignore = kv_sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct SelfAttnArg : public BaseArgument + { + SelfAttnArg(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_qkv_grid_{p_qkv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + sequence_length_{sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_qkv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return SelfAttnArg{ + p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha}; + } + + struct CrossAttnArg : public BaseArgument + { + CrossAttnArg(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_q_grid_{p_q_grid}, + p_kv_grid_{p_kv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + q_sequence_length_{q_sequence_length}, + kv_sequence_length_{kv_sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_q_grid_; + const B0DataType* p_kv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t q_sequence_length_; + index_t kv_sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeCrossAttnArgument(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return CrossAttnArg{p_q_grid, + p_kv_grid, + p_out_grid, + batch_size, + q_sequence_length, + kv_sequence_length, + head_count, + head_size, + alpha}; + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct SelfAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::SelfAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_self_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_qkv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; } + + // Invoker + struct CrossAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::CrossAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_cross_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_q_grid_, + arg.p_kv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.q_sequence_length_, + arg.kv_sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = + kernel_batched_gemm_softmax_gemm_wmma_cshuffle; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000..4385d64c19 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) +// 2. C(M, N) = A(M, K) * DequantB(K, N) + +template +struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + // LDS bypass feature not implemented for dequantization pipeline. + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0) + { + // assume Scale is [1, N] + const auto scale_grid_desc_n_k = [&]() { + const auto scale_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw); + }(); + + const auto N = scale_grid_desc_n_k.GetLength(I0); + const auto K = scale_grid_desc_n_k.GetLength(I1); + // When K = 1, it might be scale tensor. + assert(K % K1 == 0 && K != 1); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1 + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseFpAintBGemm_Wmma< + BlockSize, + ADataType, + BDataType, + ScaleDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc, + BGridDesc, + ScaleGridDesc, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const ScaleDataType* p_scale_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_scale_grid_{p_scale_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_{}, + scale_grid_desc_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const ScaleDataType* p_scale_grid_; + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + ScaleGridDesc scale_grid_desc_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_fpAintB_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + ScaleDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_scale_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.scale_grid_desc_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const ScaleDataType* p_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_scale, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}, + {PipelineVersion::weight_only, "weight_only"}}; + + // clang-format off + str << "DeviceFpAintBGemm_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp index fd90c7f1ea..a2af5d6a85 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace tensor_operation { @@ -27,21 +28,22 @@ template struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template @@ -180,13 +252,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD; - using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -195,8 +267,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD(p_b_grid)}, p_ds_grid_{}, p_e_grid_{static_cast(p_e_grid)}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, + a_grid_desc{}, + b_grid_desc{}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{}, ds_grid_desc_mblock_mperblock_nblock_nperblock{}, @@ -278,8 +352,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD{}([&](auto i) { using DLayout = remove_cvref_t>; using DDataType = remove_cvref_t>; @@ -295,8 +369,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; // Last Option is W/O + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + }; if(GridwiseOp::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - remove_reference_t, - true>; // Last Option is W/O - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< - GridwiseOp, - ADataType, - BDataType, - typename GridwiseOp::DsGridPointer, - EDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - remove_reference_t< - typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - remove_reference_t, - false>; - - ave_time = - launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_ds_grid_, - arg.p_e_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -575,8 +606,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp index 98d14caa6d..a7f2305291 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" namespace ck { namespace tensor_operation { @@ -33,13 +34,14 @@ template struct DeviceGemmWmma_CShuffle : public DeviceGemm{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 = Max Vector Access Pixels static constexpr auto K1Number = Number{}; - static constexpr auto matrix_padder = - MatrixPadder{MPerBlock, NPerBlock, K0PerBlock* K1}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; - static auto MakeAGridDescriptor_K0_M_K1(index_t MRaw, index_t KRaw, index_t StrideA) + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(StrideA, I1)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } - else if constexpr(is_same_v) + else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), - make_tuple(I1, StrideA)); + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); } }(); - const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } - static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB) + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same_v) + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(I1, StrideB)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } else if constexpr(is_same_v) { - return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), - make_tuple(StrideB, I1)); + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); } }(); - const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); assert(K % K1 == 0); - const index_t K0 = K / K1; - return transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) @@ -159,56 +230,58 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm; + using GridwiseGemm = + GridwiseGemm_Wmma; // Argument struct Argument : public BaseArgument @@ -230,7 +303,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; - float ave_time = 0; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) { - const auto kernel = kernel_gemm_wmma< - GridwiseGemm, - ADataType, - BDataType, - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - true>; // Last Option is W/O - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } else { - const auto kernel = kernel_gemm_wmma< - GridwiseGemm, - ADataType, - BDataType, - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t< - typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(stream_config, - kernel, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + return launch_kernel(integral_constant{}); } - - return ave_time; } // polymorphic @@ -413,13 +445,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm || is_same_v)) + if constexpr(!(is_same_v || is_same_v || + is_same_v)) { + printf("DeviceOp err: AccDataType"); return false; } } else { + printf("DeviceOp err: Arch"); return false; } @@ -485,7 +520,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm" - << " NumPrefetch: " + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " << NumPrefetch << ", " << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", " diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 0b3de153c3..b0e0e6da76 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -196,7 +196,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle using EGridDesc_M_N = remove_cvref_t>; // GridwiseGemm - using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -217,7 +217,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, + KPerBlock, MPerWMMA, NPerWMMA, K1, @@ -232,6 +232,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, + true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, @@ -240,6 +241,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, + true, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp index 8850b13d0a..e440eb82a4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -393,12 +393,14 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle using BGridDesc_K0_N_K1 = remove_cvref_t; using CGridDesc_M_N = remove_cvref_t; - using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using CShuffleDataType = AccDataType; + + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, AccDataType, - CDataType, + CShuffleDataType, Tuple<>, CDataType, // InMemory Data Descriptor @@ -414,7 +416,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, + KPerBlock, MPerWMMA, NPerWMMA, K1, @@ -429,6 +431,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, + true, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -437,6 +440,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, + true, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp index ba2a4b0f7a..d70d462e24 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -52,22 +52,23 @@ template struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle @@ -109,11 +109,31 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle static constexpr index_t NumDTensor = DsDataType::Size(); - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr index_t KPerBlock = K0PerBlock * K1; + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = + AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1); + static constexpr auto BEnableLds = + BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); static constexpr auto conv_to_gemm_transformer = TransformConvFwdToGemm{}; @@ -122,17 +142,16 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; template - static auto - MakeAGridDescriptor_M_K(const std::array& a_g_n_c_wis_lengths, - const std::array& a_g_n_c_wis_strides, - const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides, - const std::array& e_g_n_k_wos_lengths, - const std::array& e_g_n_k_wos_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads) + static auto MakeAGridDescriptor(const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) { const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(a_g_n_c_wis_lengths, @@ -149,13 +168,44 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); - return in_gemmm_gemmk_desc; + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template - static auto - MakeBGridDescriptor_N_K(const std::array& b_g_k_c_xs_lengths, - const std::array& b_g_k_c_xs_strides) + static auto MakeBGridDescriptor(const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides) { const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(b_g_k_c_xs_lengths, @@ -164,7 +214,39 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const auto wei_gemmn_gemmk_desc = matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); - return wei_gemmn_gemmk_desc; + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } } template @@ -197,53 +279,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // desc for problem definition - using AGridDesc_M_K = remove_cvref_t( - {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>; - using BGridDesc_N_K = remove_cvref_t({}, {}))>; + using AGridDesc = + decltype(DeviceOp::MakeAGridDescriptor({}, {}, {}, {}, {}, {}, {}, {}, {}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); using DsGridDesc_M_N = remove_cvref_t; using EGridDesc_M_N = remove_cvref_t({}, {}))>; - // A desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) - { - const auto M = a_grid_desc_m_k.GetLength(I0); - const auto K = a_grid_desc_m_k.GetLength(I1); - - const auto AK1 = K1; - const auto AK0 = K / AK1; - - return transform_tensor_descriptor(a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - // B desc for source in blockwise copy - template - __host__ __device__ static constexpr auto - MakeBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) - { - const auto N = b_grid_desc_n_k.GetLength(I0); - const auto K = b_grid_desc_n_k.GetLength(I1); - - const auto BK1 = K1; - const auto BK0 = K / BK1; - - return transform_tensor_descriptor(b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - } - - using AGridDesc_AK0_M_AK1 = decltype(DeviceOp::MakeAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{})); - using BGridDesc_BK0_N_BK1 = decltype(DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{})); - // GridwiseOp - using GridwiseOp = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle< + using GridwiseOp = GridwiseGemmMultipleD_Wmma< // DataType Family ADataType, BDataType, @@ -252,8 +295,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle DsDataType, EDataType, // InMemory Data Descriptor - AGridDesc_AK0_M_AK1, - BGridDesc_BK0_N_BK1, + AGridDesc, + BGridDesc, DsGridDesc_M_N, EGridDesc_M_N, // ElementwiseOp Family @@ -264,9 +307,9 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // Tiling Family MPerBlock, NPerBlock, - K0PerBlock, - MPerWMMA, - NPerWMMA, + KPerBlock, + MPerWmma, + NPerWmma, K1, MRepeat, NRepeat, @@ -279,6 +322,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, + AEnableLds, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, @@ -287,6 +331,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, + BEnableLds, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, @@ -327,23 +372,21 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle p_ds_grid_{}, p_e_grid_{static_cast(p_e)}, num_group_{a_g_n_c_wis_lengths[0]}, - a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_g_n_c_wis_lengths, - a_g_n_c_wis_strides, - b_g_k_c_xs_lengths, - b_g_k_c_xs_strides, - e_g_n_k_wos_lengths, - e_g_n_k_wos_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads)}, - b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_g_k_c_xs_lengths, - b_g_k_c_xs_strides)}, ds_grid_desc_m_n_{}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, - a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, - b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + a_grid_desc_{DeviceOp::MakeAGridDescriptor(a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads)}, + b_grid_desc_{ + DeviceOp::MakeBGridDescriptor(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, e_grid_desc_mblock_mperblock_nblock_nperblock_{}, block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, @@ -395,8 +438,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle void Print() const { - std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; - std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; @@ -411,14 +454,12 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle // tensor descriptors for problem definiton index_t num_group_; - AGridDesc_M_K a_grid_desc_m_k_; - BGridDesc_N_K b_grid_desc_n_k_; DsGridDesc_M_N ds_grid_desc_m_n_; EGridDesc_M_N e_grid_desc_m_n_; // tensor descriptors for block/thread-wise copy - AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; - BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock @@ -465,8 +506,17 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle const index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; - const auto K = - arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); auto launch_kernel = [&](auto has_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; @@ -480,8 +530,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, remove_reference_t, @@ -501,8 +551,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle arg.b_element_op_, arg.cde_element_op_, arg.a_g_n_c_wis_lengths_[0], // Group count - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_etile_map_, @@ -670,8 +720,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle } // check Gridwise GEMM - return GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, + return GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, arg.ds_grid_desc_m_n_, arg.e_grid_desc_m_n_, arg.block_2_etile_map_); @@ -790,9 +840,19 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle << KPerBlock << ", " << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "ABlockTransferSrcScalarPerVector: " << ABlockTransferSrcScalarPerVector << ", " - << BBlockTransferSrcScalarPerVector - << ">"; + << "BBlockTransferSrcScalarPerVector: " + << BBlockTransferSrcScalarPerVector; // clang-format on return str.str(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp new file mode 100644 index 0000000000..84ad48d4c7 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -0,0 +1,1254 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = QueryGroupNumber; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceGroupedQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(arg.G1_ % QueryGroupNumber != 0) + { + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_grouped_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGroupedQueryAttentionForward_Wmma, " + << "QueryGroupNumber: " + << QueryGroupNumber << ", " + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp new file mode 100644 index 0000000000..b7551e78a2 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -0,0 +1,1244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = 1; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceMultiQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceMultiQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_multi_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_navi3_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceMultiQueryAttentionForward_Wmma" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index d6d6f74abd..0ec55984bc 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -53,7 +53,10 @@ struct MaskOutUpperTrianglePredicate template struct C0MatrixMask_impl { - C0MatrixMask_impl(index_t NRaw) : NRaw_(NRaw), predicate_(MaskOutPredicate{}) {} + __host__ __device__ C0MatrixMask_impl(index_t NRaw) + : NRaw_(NRaw), predicate_(MaskOutPredicate{}) + { + } __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const { diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 33c2cb6c6d..c6d933893e 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -123,6 +123,12 @@ struct PassThrough y = type_convert(x); } + template <> + __host__ __device__ void operator()(uint8_t& y, const uint8_t& x) const + { + y = x; + } + template <> __host__ __device__ void operator()(int8_t& y, const int32_t& x) const { @@ -663,6 +669,76 @@ struct Elu const float alpha_; }; +// support fastconvert of int8 to fp16 + +template +struct FastNumericArrayConverter +{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); }); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index a0924ae3b0..42f7c2a33f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -116,7 +116,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; // ck::Tuple static constexpr auto MakeD0sGridPointer() diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp new file mode 100644 index 0000000000..16717ff819 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -0,0 +1,1596 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] +// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +template +struct GridwiseBatchedGemmSoftmaxGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto L0PerBlock = LTilePerBlock / L1Value; + static constexpr auto AL0 = Number{}; + static constexpr auto AL1 = Number{}; + static constexpr auto BL0 = Number{}; + static constexpr auto BL1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + static constexpr auto WmmaL = 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / AK1; + constexpr auto max_lds_align = AK1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / AK1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + AK1), + make_tuple(Number{} * Number{} * AK1, + Number{} * AK1, + Number{} * AK1, + AK1, + AK1, + AK1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeB0BlockDescriptor() + { + constexpr auto b0_block_desc = [&]() { + if constexpr(B0EnableLds) + { + // K0->L->BK1 Per Block + constexpr auto K0PerBlock = KPerBlock / BK1; + constexpr auto max_lds_align = BK1; + + if constexpr(B0BlockLdsExtraL) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / BK1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BK1), + make_tuple(Number{} * Number{} * BK1, + Number{} * BK1, + Number{} * BK1, + BK1, + BK1, + BK1, + I1)); + } + }(); + + return b0_block_desc; + } + + __host__ __device__ static constexpr auto MakeB1BlockDescriptor() + { + constexpr auto b1_block_desc = [&]() { + if constexpr(B1EnableLds) + { + // L0->N->BL1 Per Block + constexpr auto max_lds_align = BL1; + + if constexpr(B1BlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BL1), + make_tuple(Number{} * BL1, BL1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BL1), max_lds_align); + } + } + else + { + constexpr auto LWmmaPerblock = LPerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL / 2 / BL1; + // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BL1), + make_tuple(Number{} * Number{} * BL1, + Number{} * BL1, + Number{} * BL1, + BL1, + BL1, + BL1, + I1)); + } + }(); + + return b1_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / AK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep() + { + constexpr auto b0_block_copy_step = [&]() { + if constexpr(B0EnableLds) + { + constexpr auto K0PerBlock = KPerBlock / BK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b0_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep() + { + constexpr auto b1_block_copy_step = [&]() { + if constexpr(B1EnableLds) + { + return make_multi_index(L0PerBlock, 0, 0); + } + else + { + constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; + + return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b1_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) + { + + constexpr auto b0_wave_desc = [&]() { + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + B0BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b0_wave_desc; + } + + template + __host__ __device__ static constexpr auto + MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&) + { + constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); + constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); + constexpr auto A_LRow = I1; + return transform_tensor_descriptor( + A1BlockDesc_AL0_M_AL1{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_LRow)), + make_unmerge_transform(make_tuple(Number{}, I1, I1)), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + template + __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&) + { + + constexpr auto b1_wave_desc = [&]() { + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); + constexpr auto B_LRow = I1; + return transform_tensor_descriptor( + B1BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0); + constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3); + constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b1_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + const index_t gemm0_bytes_end = + (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) + + SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType)); + + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + + SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType)); + + const index_t softmax_bytes_end = + SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType); + + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (LPerBlock % (LPerWmma * LRepeat)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetB0ProblemsizeLK = [&]() { + if constexpr(B0EnableLds) + { + return make_tuple(b0_grid_desc.GetLength(I1), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * + b0_grid_desc.GetLength(I5), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) * + b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6)); + } + }; + + const auto GetB1ProblemsizeNL = [&]() { + if constexpr(B1EnableLds) + { + return make_tuple(b1_grid_desc.GetLength(I1), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) * + b1_grid_desc.GetLength(I5), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) * + b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto L = GetB0ProblemsizeLK()(I0); + const auto K = GetAProblemsizeMK()[I1]; + const auto N = GetB1ProblemsizeNL()(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + return false; + } + + if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + printf("GridwiseOp: outer loop unsupport\n"); + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(LPerBlock % LTilePerBlock == 0)) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + return false; + } + + const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + printf("GridwiseOp: inner loop unsupport\n"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = math::integer_divide_ceil(K, KPerBlock); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1); + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b0_block_space_size_aligned = + B0EnableLds ? math::integer_least_multiple( + MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + static constexpr auto b1_block_space_size_aligned = + B1EnableLds ? math::integer_least_multiple( + MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a_block_space_size_aligned; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + // Feature to add, IntraThread Reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const B0ElementwiseOperation& b0_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const C0MatrixMask& c0_matrix_mask, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// set up Gemm0 +/*******************************************************************************/ + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b0_block_desc = MakeB0BlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto AK0PerBlock = KPerBlock/ AK1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/AK1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b0_block_trait = [&](){ + if constexpr(B0EnableLds) + { + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + SharedMemTrait::b0_block_space_size_aligned); + + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0DataType, + B0DataType, + decltype(b0_grid_desc), + decltype(b0_block_desc), + B0BlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + 1, + 1, + B0ThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b0_grid_desc, + make_multi_index(0, 0, 0), + b0_element_op, + b0_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/BK1Value; + auto b0_block_buf = make_static_buffer( + b0_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b0_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B0BlockTransferSrcScalarPerVector, + B0ThreadTransferSrcResetCoordinateAfterRun, + true>( + b0_grid_desc, + make_multi_index(0, + 0/(LWaves * LPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b0_block_buf = b0_block_trait()[I0]; + auto b0_blockwise_copy = b0_block_trait()[I1]; + +/*******************************************************************************/ + // Gemm0 + constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK); + + auto blockwise_gemm0 = BlockwiseGemmWMMA< + BlockSize, + ADataType, + B0DataType, + Acc0DataType, + decltype(MakeAWaveDescriptor(a_block_desc)), + decltype(MakeB0WaveDescriptor(b0_block_desc)), + MPerBlock, + LPerBlock, + KPerBlock, + MPerWmma, + LPerWmma, + MRepeat, + LRepeat, + KPack, + AEnableLds, + B0EnableLds, + true>{}; // C' = B' x A' + + + // Prepare Register for A*B0 matrix + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( + acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), + make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)), + make_pass_through_transform(laccvgprs)), + make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep(); + + const auto a_block_reset_copy_step = [&](){ + if constexpr(AEnableLds){ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); + } + else{ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 0); + } + }(); + + const auto b0_block_reset_copy_step = [&](){ + if constexpr(B0EnableLds){ + return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); + } + else{ + return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0); + } + }(); + + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); +/*******************************************************************************/ +// softmax +/*******************************************************************************/ + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + // get acc0 7D thread cluster + constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() / + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0); + constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1); + constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2); + constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3); + constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4); + constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5); + constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6); + // get acc0 thread map + constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform( + make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_l_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs)); + + constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); +/*******************************************************************************/ +// set up Gemm1 +/*******************************************************************************/ + // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm + // A1 matrix in VGPR + constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple( + Number{}, + Number{}, + Number{}); + + constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0]; + constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1]; + constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2]; + + constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor( + make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1), + make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1)); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + ADataType, + decltype(acc0_thread_desc_l0perblock_mperblock_l1), + decltype(a1_thread_desc_l0perblock_mperblock_l1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2>, + 2, + laccvgprs>{tensor_operation::element_wise::PassThrough{}}; + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize()); + + constexpr auto b1_block_desc = MakeB1BlockDescriptor(); + + auto b1_block_trait = [&](){ + if constexpr(B1EnableLds) + { + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + SharedMemTrait::b1_block_space_size_aligned); + + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ B1ElementwiseOperation, +/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1, +/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ B1DataType, +/* typename DstData, */ B1DataType, +/* typename SrcDesc, */ decltype(b1_grid_desc), +/* typename DstDesc, */ decltype(b1_block_desc), +/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL/2/L1Value; + auto b1_block_buf = make_static_buffer( + b1_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b1_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B1BlockTransferSrcScalarPerVector, + B1ThreadTransferSrcResetCoordinateAfterRun, + true>( + b1_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + }; + + auto b1_block_buf = b1_block_trait()[I0]; + auto b1_blockwise_copy = b1_block_trait()[I1]; + + constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep(); + + auto blockwise_gemm1 = + BlockwiseGemmWMMA{make_tuple(0, 0, 0, 0, 0, 0)}; + + auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const auto L = [&](){ + if constexpr(B0EnableLds){ + return b0_grid_desc.GetLength(I1); + } + else{ + return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5); + } + }(); + + const index_t num_gemm1_l_block_outer_loop = L / LPerBlock; + constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock; + + // Initialize C + StaticBuffer c_thread_buf; + c_thread_buf.Clear(); + +/*******************************************************************************/ + // + // Kernel Main Stage + // + // Flash Attention + // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022). + index_t gemm1_l_block_outer_index = 0; + // Outer loop, along GEMM_L + // Inner loop, along GEMM_K + do{ + auto l_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock)) + { + continue; + } + // gemm0 start, A-B swaped + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b0_grid_desc, + b0_block_desc, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + KBlockMainLoop); + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 7d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + // 7d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + constexpr auto MREPEAT = c_block_lengths[I0]; + constexpr auto MWAVE = c_block_lengths[I1]; + constexpr auto MTHREADSubGroup = c_block_lengths[I2]; + constexpr auto LREPEAT = c_block_lengths[I3]; + constexpr auto LWAVE = c_block_lengths[I4]; + constexpr auto LSUBGROUP = c_block_lengths[I5]; + constexpr auto LACCVGPRS = c_block_lengths[I6]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D( + Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)), + make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto l_global = l_local + l_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, l_global)) + { + acc0_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); + } + }); + } + else + { static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + } + + block_sync_lds(); + // Tiled softmax start + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc0_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + + // main body + if constexpr(num_gemm1_l_block_inner_loop > 1) + { + static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { + // Data cast from Acc0DataType to ADataType happen here + a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple( + Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, + c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, + a_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc, + b0_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma + )), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NSubGroup, + NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 8, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp new file mode 100644 index 0000000000..67e211ef8d --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,1046 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \ + defined(__gfx1102__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_scale_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + scale_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_scale_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = scale_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx1100__)) +} + +// Assume B is Col-Major +template +struct GridwiseFpAintBGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and Dequantized B: be careful of DataType + // scale would not put into LDS. + using LDS_ADataType = ADataType; + using LDS_BDataType = ADataType; + using LDS_CDataType = CShuffleDataType; + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + // B would be dequantize to ADataType before enter LDS + // b_lds_offset = LDS size allocated for a in byte / LDS_BDataType + static constexpr auto b_block_space_offset = + (a_block_space_offset + a_block_space_size_aligned) * sizeof(LDS_ADataType) / + sizeof(LDS_BDataType); + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(LDS_CDataType), + a_block_space_size_aligned * sizeof(LDS_ADataType) + + b_block_space_size_aligned * sizeof(LDS_BDataType)); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const ScaleGridDesc& scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto scale_grid_buf = make_dynamic_buffer( + p_scale_grid, scale_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1_dequant, +/* typename BlockScaleSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1, +/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ BDataType, +/* typename ScaleData, */ ScaleDataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(b_grid_desc), +/* typename ScaleDesc, */ decltype(scale_grid_desc), +/* typename DstDesc, */ decltype(b_block_desc), +/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector, +/* index_t ScaleScalarPerVector, */ 1, +/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t ScaleScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + scale_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + ck::tensor_operation::element_wise::PassThrough{}, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + scale_grid_desc, + scale_grid_buf, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index f514e3a119..82d010a99a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -45,8 +45,8 @@ __global__ void const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock @@ -69,7 +69,7 @@ __global__ void const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; DsPointer p_ds_grid_grp; @@ -84,8 +84,8 @@ __global__ void p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock_, a_element_op, @@ -98,8 +98,8 @@ __global__ void ignore = p_ds_grid; ignore = p_e_grid; ignore = batch_count; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; ignore = a_element_op; @@ -115,8 +115,8 @@ template (compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( @@ -170,20 +169,16 @@ __global__ void DsPointer p_ds_grid_grp; - // printf("before allocate pointer d"); - static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - // printf("before entry"); - GridwiseOp::template Run(p_a_grid + a_batch_offset, p_b_grid + b_batch_offset, p_ds_grid_grp, p_e_grid + e_batch_offset, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -199,8 +194,8 @@ __global__ void ignore = a_element_op; ignore = b_element_op; ignore = cde_element_op; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = block_2_etile_map; @@ -213,8 +208,8 @@ template (p_a_grid, p_b_grid, p_ds_grid, p_e_grid, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, ds_grid_desc_mblock_mperblock_nblock_nperblock, e_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, @@ -263,8 +258,8 @@ __global__ void ignore = p_b_grid; ignore = p_ds_grid; ignore = p_e_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; @@ -282,8 +277,8 @@ template < // DataType Family typename DsDataType, typename EDataType, // InMemory Data Descriptor - typename AGridDesc_K0_M_K1, - typename BGridDesc_K0_N_K1, + typename AGridDesc, + typename BGridDesc, typename DsGridDesc_M_N, typename EGridDesc_M_N, // ElementwiseOp Family @@ -294,7 +289,7 @@ template < // DataType Family // Tiling Family index_t MPerBlock, index_t NPerBlock, - index_t K0PerBlock, + index_t KPerBlock, index_t MPerWmma, index_t NPerWmma, index_t K1Value, @@ -309,6 +304,7 @@ template < // DataType Family index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_K1, bool AThreadTransferSrcResetCoordinateAfterRun, + bool AEnableLds, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_K0_N_K1, typename BBlockTransferThreadClusterArrangeOrder, @@ -317,6 +313,7 @@ template < // DataType Family index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_K1, bool BThreadTransferSrcResetCoordinateAfterRun, + bool BEnableLds, bool BBlockLdsExtraN, index_t CShuffleMRepeatPerShuffle, index_t CShuffleNRepeatPerShuffle, @@ -325,7 +322,7 @@ template < // DataType Family index_t NumGemmKPrefetchStage = 1, LoopScheduler LoopSched = make_default_loop_scheduler(), PipelineVersion PipelineVer = PipelineVersion::v1> -struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle +struct GridwiseGemmMultipleD_Wmma { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -341,53 +338,233 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // K1 should be Number<...> static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return a_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor() { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; } __host__ __device__ static constexpr auto @@ -419,43 +596,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle Number{}); } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto c_block_space_size_aligned = math::integer_least_multiple( - cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(), - max_lds_align); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) + - b_block_space_size_aligned * sizeof(BDataType)), - c_block_space_size_aligned * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // CheckValidity for kernels without multi D template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const EGridDesc_M_N& e_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -464,20 +610,55 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { @@ -492,8 +673,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { return false; @@ -502,17 +683,57 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle return true; } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const DsGridDesc_M_N& ds_grid_desc_m_n, - const EGridDesc_M_N& e_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - bool valid = true; + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + bool valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && N == ds_grid_desc_m_n[i].GetLength(I1)); @@ -520,16 +741,52 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle if(!valid) { + printf("GridwiseOp: D descriptor dimension check failure\n"); return false; } - return CheckValidity( - a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map); + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; } __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -542,9 +799,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const auto M = e_grid_desc_m_n.GetLength(I0); const auto N = e_grid_desc_m_n.GetLength(I1); - const auto MBlock = M / MPerBlock; - const auto NBlock = N / NPerBlock; - + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( e_grid_desc_m_n, make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), @@ -575,6 +831,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle e_grid_desc_m_n); } + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -591,8 +878,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle DsGridPointer p_ds_grid, EDataType* __restrict__ p_e_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& @@ -602,14 +889,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle const CDEElementwiseOperation& cde_element_op, const Block2CTileMap& block_2_ctile_map) { - // printf("safe entry"); // clang-format off /*******************************************************************************/ // Memory buffer zone. const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); const auto ds_grid_buf = generate_tuple( [&](auto i) { return make_dynamic_buffer( @@ -635,13 +921,30 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /*******************************************************************************/ // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc.GetElementSpaceSize()); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -661,92 +964,189 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b_block_desc.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - +/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); /*******************************************************************************/ // write out to C, implement shuffle { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp index ecbcb61f3e..567c42362c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -17,18 +17,21 @@ enum struct PipelineVersion v2, // v3 is only used in the Stream-K implementation. v4, + weight_only, }; template + LoopScheduler LoopSched = LoopScheduler::Default, + bool AEnableLds = true, + bool BEnableLds = true> constexpr auto GridwiseGemmPipeline_Selector() { if constexpr(PipelineVer == PipelineVersion::v1) { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { @@ -43,6 +46,10 @@ constexpr auto GridwiseGemmPipeline_Selector() { return GridwiseGemmPipeline_v4{}; } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return GridwiseGemmPipeline_v1_WeightOnly{}; + } else { std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp index 754a3e89c9..0cdb7ce2ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -9,12 +9,12 @@ namespace ck { -template +template struct GridwiseGemmPipeline_v1; // 1-stage prefetch template <> -struct GridwiseGemmPipeline_v1<1> +struct GridwiseGemmPipeline_v1<1, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -108,7 +108,7 @@ struct GridwiseGemmPipeline_v1<1> // 2-stage prefetch template <> -struct GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipeline_v1<2, true, true> { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -254,6 +254,406 @@ struct GridwiseGemmPipeline_v1<2> } }; +template <> +struct GridwiseGemmPipeline_v1<1, false, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_block_buf = a_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, true, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, false, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_block_buf = a_block_buf_switch; + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template +struct GridwiseGemmPipeline_v1_WeightOnly; + +template <> +struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleGridBuffer& scale_grid_buf, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Global Prefetch Stage 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + // Scale read once + b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // Dequantization fused in blockwise_copy + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + template struct GridwiseGemmPipelineInterwave_v1; @@ -349,7 +749,7 @@ struct GridwiseGemmPipelineInterwave_v1<1> // Note: 2 stage prefetch not optimized for inter-wave loop scheduler template <> -struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true> { }; @@ -359,7 +759,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector() { if constexpr(LoopSched == LoopScheduler::Default) { - return GridwiseGemmPipeline_v1{}; + return GridwiseGemmPipeline_v1{}; } else if constexpr(LoopSched == LoopScheduler::Interwave) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index e7dc0d3eb0..0078660556 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -93,7 +93,7 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 066cfc62f2..8e4117593c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -18,11 +18,11 @@ namespace ck { template (p_a_grid, p_b_grid, p_c_grid, p_shared, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, + a_grid_desc, + b_grid_desc, c_grid_desc_mblock_mperblock_nblock_nperblock, a_element_op, b_element_op, @@ -67,8 +63,8 @@ __global__ void ignore = p_a_grid; ignore = p_b_grid; ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; + ignore = a_grid_desc; + ignore = b_grid_desc; ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; ignore = a_element_op; ignore = b_element_op; @@ -78,21 +74,21 @@ __global__ void } template -struct GridwiseGemm_k0mk1_k0nk1_mn_wmma +struct GridwiseGemm_Wmma { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -132,103 +130,277 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma static constexpr auto I6 = Number<6>{}; static constexpr auto I7 = Number<7>{}; - // K1 should be Number<...> + // FIX ME: To be deprecated static constexpr auto K1 = Number{}; + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + using ThisThreadBlock = ThisThreadBlock; - using GridwiseGemmPipe = remove_cvref_t< - decltype(GridwiseGemmPipeline_Selector())>; + using GridwiseGemmPipe = + remove_cvref_t())>; - __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() { - constexpr auto max_lds_align = K1; - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_k0perblock_mperblock_k1 = [&]() { - if constexpr(ABlockLdsExtraM) + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return a_block_desc_k0perblock_mperblock_k1; + return a_block_desc; } - __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + __host__ __device__ static constexpr auto MakeBBlockDescriptor() { - constexpr auto max_lds_align = K1; - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() { - if constexpr(BBlockLdsExtraN) + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) { - return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), - make_tuple(Number{} * K1, K1, I1)); + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } } else { - return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); } }(); - return b_block_desc_k0perblock_nperblock_k1; + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); + constexpr auto B_KRow = I1; + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; } __host__ __device__ static constexpr auto // *Caution Here repeat is shuffle repeat GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() { - constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = make_naive_tensor_descriptor_packed( make_tuple(I1, - Number{}, + Number{}, I1, - Number{})); + Number{})); return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; } - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_k0perblock_mperblock_k1 = - GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - - constexpr auto b_block_desc_k0perblock_nperblock_k1 = - GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - - constexpr auto max_lds_align = K1; - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align); - - return (a_block_space_size_aligned * sizeof(FloatA) + - b_block_space_size_aligned * sizeof(FloatB)); - } - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template - __host__ __device__ static constexpr bool - CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, - const CGridDesc_M_N& c_grid_desc_m_n, - const Block2CTileMap& block_2_ctile_map) + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) { static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); @@ -237,23 +409,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma (NPerBlock % (NRepeat * NPerWmma)) == 0, "Invalid tuning param!"); - const auto M = a_grid_desc_k0_m_k1.GetLength(I1); - const auto N = b_grid_desc_k0_n_k1.GetLength(I1); - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && - K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && - K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); return false; + } - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); return false; + } // check gridwise gemm pipeline - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { + printf("GridwiseOp err: Pipeline not support this k_loop"); return false; } @@ -265,8 +480,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) constexpr long_index_t TwoGB = (long_index_t{1} << 31); - if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB && - b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB)) + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) { return false; } @@ -275,7 +490,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) { - const index_t num_loop = K / (K0PerBlock * K1); + const index_t num_loop = K / KPerBlock; return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); } @@ -313,13 +528,44 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma using DefaultBlock2CTileMap = remove_cvref_t; + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + template - __device__ static void Run(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, void* __restrict__ p_shared, - const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation& a_element_op, @@ -331,9 +577,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /*******************************************************************************/ // Memory buffer zone. const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + p_a_grid, a_grid_desc.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + p_b_grid, b_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -351,24 +597,41 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); /*******************************************************************************/ -// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy - const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); - constexpr auto max_lds_align = K1; - constexpr auto a_block_desc_k0perblock_mperblock_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); - constexpr auto b_block_desc_k0perblock_nperblock_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, -/* typename SrcData, */ FloatA, -/* typename DstData, */ FloatA, -/* typename SrcDesc, */ decltype(a_grid_desc_k0_m_k1), -/* typename DstDesc, */ decltype(a_block_desc_k0perblock_mperblock_k1), +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, /* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, /* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, @@ -378,99 +641,197 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma /* index_t SrcScalarStrideInVector, */ 1, /* index_t DstScalarStrideInVector, */ 1, /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, -/* bool ThreadTransferDstResetCoordinateAfterRun, */ true>( - a_grid_desc_k0_m_k1, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, make_multi_index(0, m_block_data_idx_on_grid, 0), a_element_op, - a_block_desc_k0perblock_mperblock_k1, + a_block_desc, make_multi_index(0, 0, 0), ck::tensor_operation::element_wise::PassThrough{}); - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_K0_N_K1, - BBlockTransferThreadClusterArrangeOrder, - FloatB, - FloatB, - decltype(b_grid_desc_k0_n_k1), - decltype(b_block_desc_k0perblock_nperblock_k1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_grid_desc_k0_n_k1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_k0perblock_nperblock_k1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; /*******************************************************************************/ // GEMM - constexpr auto WmmaK = 16; constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); auto blockwise_gemm = - BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle{}; + BlockwiseGemmWMMA{}; // Prepare Register for C matrix auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); -/*******************************************************************************/ - constexpr auto a_block_space_size_aligned = math::integer_least_multiple(a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize(), max_lds_align); - // LDS allocation for A and B: be careful of alignment - auto a_block_buf = make_dynamic_buffer(static_cast(p_shared), a_block_desc_k0perblock_mperblock_k1.GetElementSpaceSize()); - auto b_block_buf = make_dynamic_buffer(static_cast(p_shared) + a_block_space_size_aligned, b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize()); - +/*******************************************************************************/ // Shift Per SUB_K - constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); // gridwise GEMM pipeline - const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); - GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, - a_block_desc_k0perblock_mperblock_k1, + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_slice_copy_step, - b_grid_desc_k0_n_k1, - b_block_desc_k0perblock_nperblock_k1, + b_grid_desc, + b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_slice_copy_step, blockwise_gemm, c_thread_buf, - K0BlockMainLoop); + KBlockMainLoop); /*******************************************************************************/ // write out to C, implement shuffle { + // C mapping in single thread. constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - // This API Provide All dimension (size) you need + // C mapping in single block constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); @@ -485,8 +846,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, @@ -532,8 +893,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, // BlockSliceLengths, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - FloatCShuffle, // typename SrcData, - FloatC, // typename DstData, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), Sequence<0, 1, 2, 3>, // typename DimAccessOrder, @@ -636,6 +997,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma if constexpr(access_id < num_access - 1) { constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 608679a4fa..3fdf686523 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1333,4 +1333,139 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ElementwiseOperation element_op_; }; +// Specilized for WMMA +// A single Wave32 is composed by double row +// Data exchange allowed between these two rows +// This RowLane Dst buf will be filled from two Src buf +// SrcA: From specific thread buffer hold by This RowLane on This Row +// SrcB: From specific thread buffer hold by This RowLane on The other Row +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. + if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply inter-row permute. + temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); + + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } + }); + }); + } + ElementwiseOperation element_op_{}; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp new file mode 100644 index 0000000000..174b82f870 --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst_idle +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +// 5. Dequantization happened between read and write. +template +struct ThreadwiseTensorSliceTransfer_v3r1_dequant +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + scale_element_op_(scale_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, + const Index& scale_slice_origin_idx) + { + scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! ScaleBuffer and ScaleData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_scale_access_lengths = + container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); + + // make forward steps + const auto scale_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto scale_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_scale_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_scale_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate scale data index + constexpr auto scale_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i] + : ordered_scale_access_lengths[i] - 1 - + ordered_scale_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * + scale_scalar_per_access; + }(); + + constexpr auto scale_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + scale_desc, scale_coord_); + + using scale_vector_type = vector_type_maker_t; + using scale_vector_t = typename scale_vector_type::type; + + // copy data from scale_buf into scale_vector_container + auto scale_vector_container = scale_vector_type{ + scale_buf.template Get(scale_coord_.GetOffset(), is_scale_valid)}; + + // copy data from scale_vector_container into scale_thread_scratch_ + scale_thread_scratch_.template SetAsType( + scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = + ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move scale coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_forward_steps[scale_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_backward_steps[scale_dim_access_order[i]]); + } + } + }); + }); + + // don't need to move scale coordinate back to slice origin + /* + if constexpr(SrcResetCoordinateAfterRun) + { + const auto scale_reset_step = + make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); + + move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); + } + */ + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + + // Do fast numeric convert + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using src_converted_vector_type = vector_type_maker_t; + using src_converted_vector_t = typename src_converted_vector_type::type; + // Vector-wise type convert + static_ford{}([&](auto access_idx) { + auto src_vector_container = src_vector_type{ + src_thread_scratch_tuple_[thread_scratch_id].template GetAsType( + access_idx)}; + + auto src_converted_vector_container = + src_converted_vector_type{fast_numeric_converter(src_vector_container)}; + + src_converted_thread_scratch_.template SetAsType( + access_idx, + src_converted_vector_container.template AsType()[I0]); + }); + + // Element-scale operation, expect packed multiplication + static_ford{}([&](auto idx) { + DstData dst_v; + constexpr auto scale_idx = Sequence{}; + // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), + // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + src_element_op_(dst_v, + src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); + dst_thread_scratch_(idx) = dst_v; + }); +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetScaleThreadScratchDescriptor() + { + + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(scale_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(scale_access_lengths_and_vector_length[i], + scale_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto scale_thread_scratch_desc_ = + decltype(GetScaleThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + /* + template + struct ScaleThreadScratchDesc{}; + */ + + // Registers, contain raw data loaded from global buffer + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain fast converted data + using SrcThreadConvertedScratch = + StaticTensorTupleOfVectorBuffer; + + // Registers, contain scale data + using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain dequantized data + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + using FastTypeConverter = tensor_operation::element_wise:: + FastNumericArrayConverter; + + StaticallyIndexedArray src_thread_scratch_tuple_; + SrcThreadConvertedScratch src_converted_thread_scratch_; + ScaleThreadScratch scale_thread_scratch_; + + DstThreadScratch dst_thread_scratch_; + FastTypeConverter fast_numeric_converter; + + SrcCoord src_coord_; + ScaleCoord scale_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const ScaleElementwiseOperation scale_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp index 814b4167b8..70fbcec10f 100644 --- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -89,6 +89,7 @@ struct wmma_type @@ -129,6 +130,7 @@ struct wmma_type @@ -153,7 +155,6 @@ struct wmma_type struct wmma_type + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { if constexpr(wave_size == 32) { - intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); } } }; - template struct wmma_type::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c); } else if constexpr(wave_size == 64) { - intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); + intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); } } }; -#endif - template struct wmma_type + bool TransposeC = false, + bool AssemblyBackend = false> struct WmmaGemm { static constexpr auto I0 = Number<0>{}; @@ -369,14 +366,14 @@ struct WmmaGemm static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - using CIndex4D = MultiIndex<4>; + using CIndex3D = MultiIndex<3>; __host__ __device__ constexpr WmmaGemm() { static_assert(NPerWmma == 16 && MPerWmma == 16, "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); - static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma"); + static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma"); } // WMMA output supporting C = A * B @@ -421,9 +418,49 @@ struct WmmaGemm Sequence<5>{})); } + // Transposed WMMA Output C' = B' * A' + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_pass_through_transform(Number{}), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_unmerge_transform(make_tuple(Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + } + __device__ static constexpr index_t GetRegSizePerWmma() { - return wmma_instr.num_acc_vgprs_per_wave; + return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number; } __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } @@ -449,14 +486,16 @@ struct WmmaGemm , "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " "(int8, int32) or (int4, int32)!"); - if constexpr(!TransposeC) - { - wmma_instr.template run(p_a_wave, p_b_wave, p_c_thread); - } - else - { - wmma_instr.template run(p_b_wave, p_a_wave, p_c_thread); - } + static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + wmma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + wmma_instr.template run(p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); } __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } @@ -477,12 +516,12 @@ struct WmmaGemm __host__ __device__ static auto CalculateAThreadOriginDataIndex() { - return GetSwizzledLaneIdLow(); + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); } __host__ __device__ static auto CalculateBThreadOriginDataIndex() { - return GetLaneIdUnderSubGroup(); + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); } __device__ static CIndex GetBeginOfThreadBlk() @@ -493,6 +532,14 @@ struct WmmaGemm return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; } + __device__ static CIndex3D GetBeginOfThreadBlk3D() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId(); + + return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0}; + } + static constexpr auto wmma = WmmaSelector{}; static constexpr auto wmma_instr = wmma.selected_wmma; @@ -500,7 +547,10 @@ struct WmmaGemm __host__ __device__ static constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() { - return make_tuple(I1, I1, Number{}); + return make_tuple(I1, + I1, + Number{}, + Number{}); } }; diff --git a/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp new file mode 100644 index 0000000000..56181d38c8 --- /dev/null +++ b/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +__host__ __device__ static auto +MakeGridDescriptorPair(const std::array& gs_ms_ns_lengths_vec, + const std::array& gs_ms_ns_strides_vec) +{ + // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + // { + // throw std::runtime_error("wrong! dimension must match input lengths"); + // } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + __host__ __device__ static auto MakeAGridDescriptorPair( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + __host__ __device__ static auto MakeAGridDescriptor_G_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + __host__ __device__ static auto MakeAGridDescriptor_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + const AGridDesc_M_K& a_grid_desc_m_k, + const WmmaK&, + const MRepeat&, + const MWaves&, + const MPerWmma&, + const AK1&) + { + const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock; + const auto K = a_grid_desc_m_k.GetLength(I1); + const auto AKWmma = K / WmmaK{}; + constexpr auto AKRow = 2; + constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{}; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform( + make_tuple(AKWmma, Number{}, Number{}, AK1{})), + make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B (alias of B0) + // + __host__ __device__ static auto MakeB0GridDescriptorPair( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeB0GridDescriptor_G_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + __host__ __device__ static auto MakeB0GridDescriptor_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + const BGridDesc_L_K& b_grid_desc_l_k, + const WmmaK&, + const LRepeat&, + const LWaves&, + const LPerWmma&, + const BK1&) + { + const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock; + const auto K = b_grid_desc_l_k.GetLength(I1); + const auto BKWmma = K / WmmaK{}; + constexpr auto BKRow = 2; + constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{}; + + return transform_tensor_descriptor( + b_grid_desc_l_k, + make_tuple(make_unmerge_transform( + make_tuple(BKWmma, Number{}, Number{}, BK1{})), + make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B1 + // + __host__ __device__ static auto MakeB1GridDescriptorPair( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + __host__ __device__ static auto MakeB1GridDescriptor_G_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + __host__ __device__ static auto MakeB1GridDescriptor_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + const BGridDesc_N_L& b_grid_desc_n_l, + const WmmaL&, + const NRepeat&, + const NWaves&, + const NPerWmma&, + const BL1&) + { + const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock; + const auto L = b_grid_desc_n_l.GetLength(I1); + const auto BLWmma = L / WmmaL{}; + constexpr auto BLRow = 2; + constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{}; + + return transform_tensor_descriptor( + b_grid_desc_n_l, + make_tuple(make_unmerge_transform( + make_tuple(BLWmma, Number{}, Number{}, BL1{})), + make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // C + // + __host__ __device__ static auto MakeCGridDescriptorPair( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeCGridDescriptor_G_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + __host__ __device__ static auto MakeCGridDescriptor_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 2ea5419d09..678c55b95f 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -417,7 +417,8 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); using r_t = typename vector_type::type; diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index 43baa817d3..5dc67a5ade 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -220,8 +220,8 @@ amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0 "0"(c0), "1"(c1)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); #endif } @@ -257,10 +257,10 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a, "2"(c2), "3"(c3)); #else - c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); - c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); - c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); - c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); #endif } @@ -355,17 +355,5 @@ __device__ void amd_assembly_outer_product_1x4(int8x16_t a, c3); } -// Ranged input operand -__device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, float8_t& c) -{ -#if defined(__gfx11__) - asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c)); -#else - ignore = a; - ignore = b; - ignore = c; -#endif -} - } // namespace ck #endif diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 31ae71880a..4d6791b5a7 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -133,6 +133,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = uint8_t; + static constexpr index_t vector_size = 1; +}; + #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 template <> struct scalar_type @@ -1037,6 +1044,14 @@ using bf8x8_t = typename vector_type::type; using bf8x16_t = typename vector_type::type; using bf8x32_t = typename vector_type::type; using bf8x64_t = typename vector_type::type; +// u8 +// i8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; template struct NumericLimits diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index dbac1f0c85..be74b1fdc1 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -99,6 +99,63 @@ inline __host__ __device__ constexpr bhalf_t type_convert(int8_ return type_convert(x_fp32); } +// Convert X to Y +template +__host__ __device__ constexpr Y type_convert_sp(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + return static_cast(x); +} + +template <> +inline __host__ __device__ constexpr int type_convert_sp(float x) +{ + union + { + float fp32; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr float type_convert_sp(int x) +{ + union + { + int int32; + float fp32; + } u = {x}; + + return u.fp32; +} + +template <> +inline __host__ __device__ constexpr int type_convert_sp(half_t x) +{ + union + { + half_t fp16; + int int32; + } u = {x}; + + return u.int32; +} + +template <> +inline __host__ __device__ constexpr half_t type_convert_sp(int x) +{ + union + { + int int32; + half_t fp16; + } u = {x}; + + return u.fp16; +} + // Declare a template function for fp8 conversion using SR template __host__ __device__ constexpr Y f8_convert_sr(X x); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index a1b1e0d91b..7a8e1d9a37 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -133,6 +133,252 @@ struct ReferenceBatchedGemm : public device::BaseOperator } }; +template +struct ReferenceBatchedGemm_MQA : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_1_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g0_g1_m_k_{a_g0_g1_m_k}, + b_g0_1_k_n_{b_g0_1_k_n}, + c_g0_g1_m_n_{c_g0_g1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g0_g1_m_k_; + const Tensor& b_g0_1_k_n_; + Tensor& c_g0_g1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm_MQA::Argument; + + float Run(const Argument& arg) + { + auto f_g0g1mk_g01kn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) { + const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k)); + arg.b_element_op_(v_b, arg.b_g0_1_k_n_(g0, 0, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_g0g1mk_g01kn_g0g1mn, + arg.c_g0_g1_m_n_.mDesc.GetLengths()[0], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[1], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[2], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_1_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_g0_g1_m_k, b_g0_1_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm_MQA" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +template +struct ReferenceBatchedGemm_GQA : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_g0_g1_m_k_{a_g0_g1_m_k}, + b_g0_gq_k_n_{b_g0_gq_k_n}, + c_g0_g1_m_n_{c_g0_g1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_g0_g1_m_k_; + const Tensor& b_g0_gq_k_n_; + Tensor& c_g0_g1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceBatchedGemm_GQA::Argument; + + float Run(const Argument& arg) + { + auto f_g0g1mk_g0gqkn_g0g1mn = [&](auto g0, auto g1, auto m, auto n) { + const int G1 = arg.a_g0_g1_m_k_.mDesc.GetLengths()[1]; + const int K = arg.a_g0_g1_m_k_.mDesc.GetLengths()[3]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + + arg.a_element_op_(v_a, arg.a_g0_g1_m_k_(g0, g1, m, k)); + arg.b_element_op_(v_b, arg.b_g0_gq_k_n_(g0, g1 * QueryGroupNumber / G1, k, n)); + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_g0_g1_m_n_(g0, g1, m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_g0g1mk_g0gqkn_g0g1mn, + arg.c_g0_g1_m_n_.mDesc.GetLengths()[0], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[1], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[2], + arg.c_g0_g1_m_n_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_g0_g1_m_k, + const Tensor& b_g0_gq_k_n, + Tensor& c_g0_g1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_g0_g1_m_k, b_g0_gq_k_n, c_g0_g1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceBatchedGemm_GQA" + << std::endl; + // clang-format on + + return str.str(); + } +}; + } // namespace host } // namespace tensor_operation } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp new file mode 100644 index 0000000000..ac392f0906 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp @@ -0,0 +1,177 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferencefpAintBGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + scale_k_n_{scale_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + const Tensor& scale_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferencefpAintBGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + AccDataType v_acc = 0; + + for(int k = 0; k < K; ++k) + { + ADataType v_a; + BDataType v_b; + ScaleDataType v_scale; + ADataType v_converted_b; + + // use PassThrough instead of ConvertBF16RTN for reference calculation + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k)); + } + else + { + arg.a_element_op_(v_a, arg.a_m_k_(m, k)); + } + + // same for B matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_b, arg.b_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_b, arg.b_k_n_(k, n)); + } + + // same for scale matrix + if constexpr(is_same_v) + { + ck::tensor_operation::element_wise::PassThrough{}(v_scale, + arg.scale_k_n_(k, n)); + } + else + { + arg.b_element_op_(v_scale, arg.scale_k_n_(k, n)); + } + + v_converted_b = type_convert(v_b) * v_scale; + v_acc += ck::type_convert(v_a) * + ck::type_convert(v_converted_b); + } + + AccDataType v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = ck::type_convert(v_c); + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& scale_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, scale_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp index 31e5b72ea1..ee9d977096 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm.hpp @@ -384,6 +384,26 @@ void add_device_gemm_xdl_c_shuffle_f16_f8_f16_mk_nk_mn_instances( instances); #endif +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances); + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances); + template && is_same_v && is_same_v) @@ -493,6 +514,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs); add_device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances( op_ptrs); + add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -505,6 +527,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs); + add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -517,6 +540,7 @@ struct DeviceOperationInstanceFactory< add_device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs); #endif add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs); + add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances(op_ptrs); } } #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp index f925397832..4ea23ea1f9 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_instance.hpp @@ -54,36 +54,36 @@ template using device_grouped_conv_fwd_wmma_f16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>, // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> // clang-format on >; @@ -97,36 +97,36 @@ template using device_grouped_conv_fwd_wmma_i8_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Prefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //generic instance - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>, // blocksize=256 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>, // blocksize=128 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>, // blocksize=64 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>, // blocksize=32 - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGroupedConvFwdMultipleD_Wmma_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt index 3d243e3d56..e9cc1e854f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt @@ -111,6 +111,12 @@ list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_kn_mn_instance.cpp device_gemm_xdl_c_shuffle_fp8_fp8_fp8_km_nk_mn_instance.cpp) +list(APPEND GEMM_INSTANCES + device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp + device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp) + add_instance_library(device_gemm_instance ${GEMM_INSTANCES}) set(ENABLE_PIPELINE_V2_OPT) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..f3665eb8d8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_wmma_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_km_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..6726727e67 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_wmma_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_km_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..d526f17b56 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise|Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 2>, 8> +#if 0 + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> +#endif + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..eed856b6ca --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_wmma_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //######################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumPrefetch| Block| MPer| NPer| KPer| K1| MPer| NPer| M| N| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| + //######################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| | Size| Block| Block| Block| | WMMA| WMMA| Repeat| Repeat| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| MRepeat| ClusterLengths| ScalarPerVector| + //######################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerStore| PerStore| MBlock_MPerBlock| | + //######################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | + /* Prefetch 2, consume enormous vgpr resource*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 128, 128, 32, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 128, 128, 64, 64, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 64, 64, 32, 32, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + /* Prefetch 1, prefer larger KPerBlock value for better latency hiding*/ + // 8 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 160, 64, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, 8>, + // 4 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 256, 64, 64, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 256, 64, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 80, 64, 8, 16, 16, 1, 5, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 2>, 8>, + // 2 Waves + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 64, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8>, + // 1 Wave + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 32, 64, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmWmma_CShuffle< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 2>, 8> + // clang-format on + >; + +void add_device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& + instances) +{ + add_device_operation_instances(instances, device_gemm_wmma_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp index 73ea9cac07..dd055fabb8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp index 1f36113e62..f607484363 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp index 688c463369..accb2f80b6 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp @@ -36,32 +36,32 @@ static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecial // e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n]) using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances = std::tuple< // clang-format off - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp index 5319bd8605..6a23b70321 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_bilinear/device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp @@ -38,56 +38,56 @@ using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = st // clang-format off // no padding // N % 16 == 0 && K % 16 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmDefault, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 16 == 0 && K % 16 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 8 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, // M/N/K padding // N % 8 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 32, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 32, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 32, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 32, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>, // M/N/K padding // N % 1 == 0 && K % 8 == 0 - //################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>, - DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1> + //################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Prefetch| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Stage| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 128, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 64, 32, 32, 64, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>, + DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I32, I32, I8_Tuple, I8, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 1, 32, 16, 16, 64, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index d8bd0de692..93d5bd7422 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -1,16 +1,18 @@ -add_instance_library(device_grouped_conv2d_bwd_data_instance - xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp - xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +add_instance_library( + device_grouped_conv2d_bwd_data_instance + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp - wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp) + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp +) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 1542d611f7..2715a8cf21 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -17,21 +17,21 @@ add_instance_library(device_grouped_conv2d_fwd_instance dl/device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp # WMMA # GNHWC, GKYXC, GNHWK - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp - # NHWGC, GKYXC, NHWGK - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_oddc_instance.cpp + ## NHWGC, GKYXC, NHWGK + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_f16_oddc_instance.cpp + wmma/device_grouped_conv2d_fwd_wmma_nhwgc_gkyxc_nhwgk_i8_oddc_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index bada661028..540ce3410b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -22,7 +22,8 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp - wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp) + wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp +) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND GROUPED_CONV3D_FWD diff --git a/test/grouped_convnd_bwd_data/CMakeLists.txt b/test/grouped_convnd_bwd_data/CMakeLists.txt index 9773e5a9c6..305c568ee9 100644 --- a/test/grouped_convnd_bwd_data/CMakeLists.txt +++ b/test/grouped_convnd_bwd_data/CMakeLists.txt @@ -1,5 +1,5 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0) diff --git a/test/grouped_convnd_bwd_weight/CMakeLists.txt b/test/grouped_convnd_bwd_weight/CMakeLists.txt index b167943c97..d7d6f8a3d6 100644 --- a/test/grouped_convnd_bwd_weight/CMakeLists.txt +++ b/test/grouped_convnd_bwd_weight/CMakeLists.txt @@ -1,5 +1,5 @@ list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942) -list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102) +list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102 gfx1103) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) From 42fc8eddd21f5725881f8f503f2cb5724c935cb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Sat, 9 Mar 2024 02:13:03 +0100 Subject: [PATCH 24/36] Fix warnings during wrapper docs generation (#1192) * Fix warnings during wrapper docs generation * Fixes --- docs/conf.py | 2 ++ docs/wrapper.rst | 14 +++++++------- include/ck/wrapper/layout.hpp | 9 +++++++++ include/ck/wrapper/operations/copy.hpp | 3 +++ include/ck/wrapper/operations/gemm.hpp | 6 ++++++ include/ck/wrapper/tensor.hpp | 9 +++++++++ .../wrapper/traits/blockwise_gemm_xdl_traits.hpp | 3 +++ include/ck/wrapper/utils/kernel_utils.hpp | 3 +++ include/ck/wrapper/utils/layout_utils.hpp | 5 ++++- include/ck/wrapper/utils/tensor_partition.hpp | 6 ++++++ include/ck/wrapper/utils/tensor_utils.hpp | 5 ++++- 11 files changed, 56 insertions(+), 9 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index e441ff1ced..e8617a09ef 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -45,3 +45,5 @@ for sphinx_var in ROCmDocs.SPHINX_VARS: extensions += ['sphinxcontrib.bibtex'] bibtex_bibfiles = ['refs.bib'] + +cpp_id_attributes = ["__global__", "__device__", "__host__"] diff --git a/docs/wrapper.rst b/docs/wrapper.rst index 39e2fd0bbd..190fbcd445 100644 --- a/docs/wrapper.rst +++ b/docs/wrapper.rst @@ -64,31 +64,31 @@ Advanced examples: Layout ------------------------------------- -.. doxygenstruct:: ck::wrapper::Layout +.. doxygenstruct:: Layout ------------------------------------- Layout helpers ------------------------------------- -.. doxygenfile:: layout_utils.hpp +.. doxygenfile:: include/ck/wrapper/utils/layout_utils.hpp ------------------------------------- Tensor ------------------------------------- -.. doxygenstruct:: ck::wrapper::Tensor +.. doxygenstruct:: Tensor ------------------------------------- Tensor helpers ------------------------------------- -.. doxygenfile:: tensor_utils.hpp +.. doxygenfile:: include/ck/wrapper/utils/tensor_utils.hpp -.. doxygenfile:: tensor_partition.hpp +.. doxygenfile:: include/ck/wrapper/utils/tensor_partition.hpp ------------------------------------- Operations ------------------------------------- -.. doxygenfile:: copy.hpp -.. doxygenfile:: gemm.hpp +.. doxygenfile:: include/ck/wrapper/operations/copy.hpp +.. doxygenfile:: include/ck/wrapper/operations/gemm.hpp diff --git a/include/ck/wrapper/layout.hpp b/include/ck/wrapper/layout.hpp index 71c512e136..5cd1f614e6 100644 --- a/include/ck/wrapper/layout.hpp +++ b/include/ck/wrapper/layout.hpp @@ -5,8 +5,11 @@ #include "ck/wrapper/utils/layout_utils.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond /** * \brief Layout wrapper that performs the tensor descriptor logic. @@ -19,6 +22,8 @@ namespace wrapper { template struct Layout { + // Disable from doxygen docs generation + /// @cond INTERNAL private: static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -246,6 +251,7 @@ struct Layout using Descriptor1dType = remove_cvref_t; using DefaultIdxsTupleType = remove_cvref_t; + /// @endcond public: using LayoutShape = Shape; @@ -457,6 +463,8 @@ struct Layout return unrolled_descriptor_; } + // Disable from doxygen docs generation + /// @cond INTERNAL private: // All dimensions are unrolled UnrolledDescriptorType unrolled_descriptor_; @@ -469,6 +477,7 @@ struct Layout // Descriptor1dType lengths: (8) // MergedNestsDescriptorType lengths: (4, 2) const Shape shape_; + /// @endcond }; } // namespace wrapper diff --git a/include/ck/wrapper/operations/copy.hpp b/include/ck/wrapper/operations/copy.hpp index 5f64031ebe..e8a919fdda 100644 --- a/include/ck/wrapper/operations/copy.hpp +++ b/include/ck/wrapper/operations/copy.hpp @@ -12,8 +12,11 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond /** * \brief Perform optimized copy between two tensors partitions (threadwise copy). diff --git a/include/ck/wrapper/operations/gemm.hpp b/include/ck/wrapper/operations/gemm.hpp index e41cd5bd8a..42a70239ad 100644 --- a/include/ck/wrapper/operations/gemm.hpp +++ b/include/ck/wrapper/operations/gemm.hpp @@ -9,9 +9,14 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond +// Disable from doxygen docs generation +/// @cond INTERNAL namespace { namespace detail { /** @@ -45,6 +50,7 @@ __device__ constexpr auto GetBlockDescriptor() } // namespace detail } // namespace +/// @endcond /** * \brief Perform blockwise gemm xdl on tensors stored in lds. Result will be diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 6946e79ea4..8dabb58451 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -7,9 +7,14 @@ #include "utils/tensor_partition.hpp" #include "utils/layout_utils.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond +// Disable from doxygen docs generation +/// @cond INTERNAL namespace { namespace detail { /** @@ -189,6 +194,7 @@ __host__ __device__ constexpr auto GenerateSlicedDescriptor(const Tuple& } } // namespace detail } // namespace +/// @endcond /** * \brief Tensor wrapper that performs static and dynamic buffer logic. @@ -394,6 +400,8 @@ struct Tensor } private: + // Disable from doxygen docs generation + /// @cond INTERNAL using DynamicBufferType = DynamicBuffer struct Layout; diff --git a/include/ck/wrapper/utils/tensor_partition.hpp b/include/ck/wrapper/utils/tensor_partition.hpp index 141e0a58e5..69fd502d63 100644 --- a/include/ck/wrapper/utils/tensor_partition.hpp +++ b/include/ck/wrapper/utils/tensor_partition.hpp @@ -9,9 +9,14 @@ #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" #include "ck/tensor_description/cluster_descriptor.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond +// Disable from doxygen docs generation +/// @cond INTERNAL namespace { namespace detail { @@ -236,6 +241,7 @@ __host__ __device__ constexpr auto CalculateThreadMultiIdx( } } // namespace detail } // namespace +/// @endcond /** * \brief Create local partition for thread (At now only packed partition diff --git a/include/ck/wrapper/utils/tensor_utils.hpp b/include/ck/wrapper/utils/tensor_utils.hpp index ee9e438a40..ccab99fac3 100644 --- a/include/ck/wrapper/utils/tensor_utils.hpp +++ b/include/ck/wrapper/utils/tensor_utils.hpp @@ -13,8 +13,11 @@ #include "ck/utility/amd_address_space.hpp" #include "ck/utility/multi_index.hpp" +// Disable from doxygen docs generation +/// @cond INTERNAL namespace ck { namespace wrapper { +/// @endcond /** * \brief Memory type, allowed members: @@ -27,7 +30,7 @@ namespace wrapper { using MemoryTypeEnum = AddressSpaceEnum; // Disable from doxygen docs generation -/// @cond +/// @cond INTERNAL // forward declarations template struct Layout; From 8e97e85ac6cdb71903d3ac46a7e82f8350eb0ce5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 Mar 2024 08:21:14 -0700 Subject: [PATCH 25/36] Bump rocm-docs-core from 0.35.1 to 0.36.0 in /docs/sphinx (#1194) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.35.1 to 0.36.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.35.1...v0.36.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index 93c15a2160..b3c8267736 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.35.1 +rocm-docs-core==0.36.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 8faeac85db..ba1d7da441 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -113,7 +113,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.35.1 +rocm-docs-core==0.36.0 # via -r requirements.in six==1.16.0 # via From 12441af014d8c865a0254efea3f82a07bdf58b4f Mon Sep 17 00:00:00 2001 From: randyh62 <42045079+randyh62@users.noreply.github.com> Date: Tue, 12 Mar 2024 18:25:48 -0700 Subject: [PATCH 26/36] Doc reorg2 (#1189) * doc_reorg2 updated TOC * doc_reorg2 updates * fix conflicts, add grid --- docs/{ => conceptual}/what-is-ck.rst | 4 +-- docs/index.rst | 25 +++++++------- docs/{ => install}/dockerhub.rst | 0 docs/license.md | 2 -- docs/license.rst | 11 +++++++ docs/{ => reference}/API_Reference_Guide.rst | 0 .../Supported_Primitives_Guide.rst | 0 docs/{ => reference}/wrapper.rst | 0 docs/sphinx/_toc.yml.in | 33 ++++++++++++++----- docs/{ => tutorial}/tutorial_hello_world.rst | 0 10 files changed, 49 insertions(+), 26 deletions(-) rename docs/{ => conceptual}/what-is-ck.rst (94%) rename docs/{ => install}/dockerhub.rst (100%) delete mode 100644 docs/license.md create mode 100644 docs/license.rst rename docs/{ => reference}/API_Reference_Guide.rst (100%) rename docs/{ => reference}/Supported_Primitives_Guide.rst (100%) rename docs/{ => reference}/wrapper.rst (100%) rename docs/{ => tutorial}/tutorial_hello_world.rst (100%) diff --git a/docs/what-is-ck.rst b/docs/conceptual/what-is-ck.rst similarity index 94% rename from docs/what-is-ck.rst rename to docs/conceptual/what-is-ck.rst index f0b51c48f8..36785fc6ca 100644 --- a/docs/what-is-ck.rst +++ b/docs/conceptual/what-is-ck.rst @@ -20,7 +20,7 @@ CK utilizes two concepts to achieve performance portability and code maintainabi * Algorithm complexity reduction for complex ML operators using an innovative technique called "Tensor Coordinate Transformation". -.. image:: data/ck_component.png +.. image:: ../data/ck_component.png :alt: CK Components @@ -36,6 +36,6 @@ The CK library is structured into 4 layers: It also includes a simple wrapper component used to perform tensor transform operations more easily and with fewer lines of code. -.. image:: data/ck_layer.png +.. image:: ../data/ck_layer.png :alt: CK Layers \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 8ae4ce3a22..55c80b8edf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -12,27 +12,26 @@ The Composable Kernel (CK) library provides a programming model for writing perf The CK documentation is structured as follows: -.. card:: Conceptual +.. grid:: 2 + :gutter: 3 - * :ref:`what-is-ck` + .. grid-item-card:: Installation -.. card:: Installation + * :ref:`docker-hub` - * :ref:`docker-hub` + .. grid-item-card:: Conceptual -.. card:: Tutorial + * :ref:`what-is-ck` - * :ref:`hello-world` + .. grid-item-card:: API reference -.. card:: API reference + * :ref:`supported-primitives` + * :ref:`api-reference` + * :ref:`wrapper` - * :ref:`supported-primitives` - * :ref:`api-reference` - * :ref:`wrapper` + .. grid-item-card:: Tutorial -.. card:: Contributing to CK - - * :ref:`contributing-to` + * :ref:`hello-world` To contribute to the documentation refer to `Contributing to ROCm `_. diff --git a/docs/dockerhub.rst b/docs/install/dockerhub.rst similarity index 100% rename from docs/dockerhub.rst rename to docs/install/dockerhub.rst diff --git a/docs/license.md b/docs/license.md deleted file mode 100644 index 43e471da0e..0000000000 --- a/docs/license.md +++ /dev/null @@ -1,2 +0,0 @@ -```{include} ../LICENSE.md -``` diff --git a/docs/license.rst b/docs/license.rst new file mode 100644 index 0000000000..1e5389ccc1 --- /dev/null +++ b/docs/license.rst @@ -0,0 +1,11 @@ +.. meta:: + :description: Composable Kernel documentation and API reference library + :keywords: composable kernel, CK, ROCm, API, documentation + +.. _license: + +******************************************************************** +License +******************************************************************** + +.. include:: ../LICENSE \ No newline at end of file diff --git a/docs/API_Reference_Guide.rst b/docs/reference/API_Reference_Guide.rst similarity index 100% rename from docs/API_Reference_Guide.rst rename to docs/reference/API_Reference_Guide.rst diff --git a/docs/Supported_Primitives_Guide.rst b/docs/reference/Supported_Primitives_Guide.rst similarity index 100% rename from docs/Supported_Primitives_Guide.rst rename to docs/reference/Supported_Primitives_Guide.rst diff --git a/docs/wrapper.rst b/docs/reference/wrapper.rst similarity index 100% rename from docs/wrapper.rst rename to docs/reference/wrapper.rst diff --git a/docs/sphinx/_toc.yml.in b/docs/sphinx/_toc.yml.in index 5780674624..533b81cd39 100644 --- a/docs/sphinx/_toc.yml.in +++ b/docs/sphinx/_toc.yml.in @@ -2,20 +2,35 @@ defaults: numbered: False root: index subtrees: -- entries: - - file: what-is-ck.rst + +- caption: Conceptual + entries: + - file: conceptual/what-is-ck.rst title: What is Composable Kernel? - - file: dockerhub.rst + +- caption: Install + entries: + - file: install/dockerhub.rst title: Docker Hub - - file: tutorial_hello_world.rst - title: Hello World Tutorial - - file: Supported_Primitives_Guide.rst + +- caption: CK API Reference + entries: + - file: reference/Supported_Primitives_Guide.rst title: Supported Primitives - - file: API_Reference_Guide.rst + - file: reference/API_Reference_Guide.rst title: API Reference - - file: wrapper.rst + - file: reference/wrapper.rst title: Wrapper + +- caption: Tutorial + entries: + - file: tutorial/tutorial_hello_world.rst + title: Hello World Tutorial + +- caption: About + entries: - file: Contributors_Guide.rst title: Contributing to CK - - file: license.md + - file: license.rst title: License + \ No newline at end of file diff --git a/docs/tutorial_hello_world.rst b/docs/tutorial/tutorial_hello_world.rst similarity index 100% rename from docs/tutorial_hello_world.rst rename to docs/tutorial/tutorial_hello_world.rst From 285251768e8026689411d330def1aa6a2329b544 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 13 Mar 2024 23:09:08 +0100 Subject: [PATCH 27/36] Add conv fwd/bwd data scale instances, extend bilinear instances (#1178) * Add conv fwd/bwd data scale instances * Fix cmake client example file --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../24_grouped_conv_activation/CMakeLists.txt | 8 + .../grouped_conv_bwd_data_scale_fp16.cpp | 216 +++++++++++++++++ .../grouped_conv_fwd_scale_fp16.cpp | 220 ++++++++++++++++++ .../element/unary_element_wise_operation.hpp | 6 + ...ed_conv_bwd_data_xdl_bilinear_instance.hpp | 131 ++++++----- ...ouped_conv_bwd_data_xdl_scale_instance.hpp | 149 ++++++++++++ ...grouped_conv_fwd_xdl_bilinear_instance.hpp | 112 ++++++--- ...ce_grouped_conv_fwd_xdl_scale_instance.hpp | 179 ++++++++++++++ ...rouped_convolution_backward_data_scale.hpp | 150 ++++++++++++ .../gpu/grouped_convolution_forward_scale.hpp | 175 ++++++++++++++ .../CMakeLists.txt | 6 + ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 50 ++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 50 ++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 50 ++++ .../grouped_conv3d_fwd_scale/CMakeLists.txt | 7 + ...ale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp | 55 +++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp | 54 +++++ ...cale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp | 54 +++++ ...ale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp | 54 +++++ 19 files changed, 1637 insertions(+), 89 deletions(-) create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp create mode 100644 client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp diff --git a/client_example/24_grouped_conv_activation/CMakeLists.txt b/client_example/24_grouped_conv_activation/CMakeLists.txt index b4895db891..074dcd9b97 100644 --- a/client_example/24_grouped_conv_activation/CMakeLists.txt +++ b/client_example/24_grouped_conv_activation/CMakeLists.txt @@ -38,3 +38,11 @@ target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE c add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) target_link_libraries(client_grouped_convnd_bwd_data_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) +# Fwd scale +add_executable(client_grouped_convnd_fwd_scale_fp16 + grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_fwd_scale_fp16 PRIVATE composable_kernel::device_conv_operations) +# Bwd data scale +add_executable(client_grouped_convnd_bwd_data_scale_fp16 + grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp) +target_link_libraries(client_grouped_convnd_bwd_data_scale_fp16 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp new file mode 100644 index 0000000000..e53ecc6c99 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -0,0 +1,216 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_bwd_data_scale() +{ + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * G * N * Di * Hi * Wi * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * G * N * Do * Ho * Wo * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + ck::Tuple<>, + InDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = std::size_t(2) * G * N * K * C * Do * Ho * Wo * Y * X + + 3 * G * N * Di * Hi * Wi * C; + std::size_t num_bytes = 2 * sizeof(InDataType) * G * N * Di * Hi * Wi * C + + sizeof(WeiDataType) * G * K * Z * Y * X * C + + sizeof(OutDataType) * G * N * Do * Ho * Wo * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + in.GetDeviceBuffer(), + out_lengths, + out_strides, + wei_lengths, + wei_strides, + {}, + {}, + in_lengths, + in_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_bwd_data_scale(); } diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp new file mode 100644 index 0000000000..11e69f5bb2 --- /dev/null +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +// Use std tuple instead of ck tuple to avoid clang +// implicit instantiation of undefined template error. +using DDataTypes = std::tuple; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 64; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Z = 3; // filter D +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Di = 14; // input D +static constexpr ck::index_t Hi = 14; // input H +static constexpr ck::index_t Wi = 14; // input W +static constexpr ck::index_t Do = 14; // output D +static constexpr ck::index_t Ho = 14; // output H +static constexpr ck::index_t Wo = 14; // output W + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +int execute_conv_fwd_scale() +{ + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space. + // However, CK's API only accepts lengths and strides with order of GNCDHW/GKCZYX/GNKDHW. + // Hence, we need to adjust the order of strides. + std::array in_lengths{G, N, C, Di, Hi, Wi}; + std::array in_strides{ + C, Di * Hi * Wi * G * C, 1, Hi * Wi * G * C, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Z, Y, X}; + std::array wei_strides{ + K * Z * Y * X * C, Z * Y * X * C, 1, Y * X * C, X * C, C}; + std::array out_lengths{G, N, K, Do, Ho, Wo}; + std::array out_strides{ + K, Do * Ho * Wo * G * K, 1, Ho * Wo * G * K, Wo * G * K, G * K}; + // Logical broadcast bias (we have to pass bias lengths in the same format as output - GNKDHW) + std::array bias_lengths{G, 1, K, 1, 1, 1}; + std::array bias_strides{K, 0, 1, 0, 0, 0}; + + std::array filter_strides{1, 1, 1}; + std::array filter_dilations{1, 1, 1}; + std::array input_left_pads{1, 1, 1}; + std::array input_right_pads{1, 1, 1}; + + SimpleDeviceMem in(sizeof(InDataType) * N * Di * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Z * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Do * Ho * Wo * G * K); + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + Scale>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + std::size_t flop = + std::size_t(2) * G * N * K * C * Ho * Wo * Y * X + 3 * N * Ho * Wo * G * K; + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + + sizeof(WeiDataType) * G * K * Y * X * C + + sizeof(OutDataType) * 2 * N * Ho * Wo * G * K; + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return EXIT_FAILURE; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + {}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + {}, + {}, + out_lengths, + out_strides, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + Scale{2.f}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return 0; +} + +int main() { return execute_conv_fwd_scale(); } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index c6d933893e..9c64ad4dfa 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -310,6 +310,12 @@ struct Scale y = scale_ * x; }; + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = ck::type_convert(scale_ * ck::type_convert(x)); + }; + float scale_; }; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp index 93a1ef2096..216b4e2fe7 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp @@ -18,8 +18,6 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; -using BF8 = ck::bf8_t; -using F8 = ck::f8_t; template using S = ck::Sequence; @@ -35,27 +33,42 @@ static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; // f16_f16_f32_f16 + template -using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_grouped_conv_bwd_data_xdl_bilinear_f16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, F16, F16, F32, F16, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> - // clang-format on - >; + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; // bf16_bf16_f32_bf16 template using device_grouped_conv_bwd_data_xdl_bilinear_bf16_instances = std::tuple< // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, ELayout, BF16, BF16, F32, BF16, Tuple, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -87,44 +113,35 @@ template -using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_grouped_conv_bwd_data_xdl_bilinear_f32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4> - // clang-format on - >; - -// f16_f16_f16_comp_f8 -template -using device_grouped_conv_bwd_data_xdl_bilinear_input_fp16_comp_bf8f8_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // generic instance - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, - // instances for small conv.K and conv.C - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, ck::Tuple, ELayout, F16, F16, F32, F32, Tuple, F16, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8> - // clang-format on - >; + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple, F32, PassThrough, PassThrough, Bilinear, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp new file mode 100644 index 0000000000..d278b9a482 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +// f16_f16_f32_f16 + +template +using device_grouped_conv_bwd_data_xdl_scale_f16_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, F16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// bf16_bf16_f32_bf16 +template +using device_grouped_conv_bwd_data_xdl_scale_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, BF16, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +// f32_f32_f32_f32 +template +using device_grouped_conv_bwd_data_xdl_scale_f32_instances = + std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Tuple<>, F32, PassThrough, PassThrough, Scale, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp index 3c689990aa..1c3bfef8ce 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_bilinear_instance.hpp @@ -45,17 +45,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_bf16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -67,17 +79,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_f16_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; @@ -89,17 +113,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_f32_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> // clang-format on >; @@ -111,17 +147,29 @@ template using device_grouped_conv_fwd_xdl_bilinear_int8_instances = std::tuple< // clang-format off - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, // instances for small conv.K and conv.C - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Bilinear, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp new file mode 100644 index 0000000000..f4dfc8f773 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +template +using device_grouped_conv_fwd_xdl_scale_bf16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, BF16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_f16_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F16, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_f32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, F32, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +template +using device_grouped_conv_fwd_xdl_scale_int8_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, int8_t, PassThrough, PassThrough, Scale, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp new file mode 100644 index 0000000000..c25c492e40 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_scale.hpp @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +#ifdef CK_ENABLE_BF16 +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD< + NumDimSpatial, + OutLayout, + WeiLayout, + Tuple<>, + InLayout, + OutDataType, + WeiDataType, + Tuple<>, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale, + ComputeTypeA, + ComputeTypeB>> +{ + using DeviceOp = + DeviceGroupedConvBwdDataMultipleD, + InLayout, + OutDataType, + WeiDataType, + Tuple<>, + InDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Scale, + ComputeTypeA, + ComputeTypeB>; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP32 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + else if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + op_ptrs); + } +#endif + } + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp new file mode 100644 index 0000000000..c4bc1da57e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using Scale = ck::tensor_operation::element_wise::Scale; + +#ifdef CK_ENABLE_BF16 +// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP16 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_FP32 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +#ifdef CK_ENABLE_INT8 +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple<>, + int8_t, + PassThrough, + PassThrough, + Scale>>>& instances); +#endif + +template +struct DeviceOperationInstanceFactory> +{ + using DeviceOp = + DeviceGroupedConvFwdMultipleABD; + + static auto GetInstances() + { + std::vector> op_ptrs; + if constexpr(NumDimSpatial == 3 && is_same_v && + is_same_v && is_same_v && + DLayouts::Size() == 0) + { +#ifdef CK_ENABLE_FP32 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_FP16 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs); + } +#endif +#ifdef CK_ENABLE_BF16 + if constexpr(is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + op_ptrs); + } +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + op_ptrs); + } +#endif + } + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt new file mode 100644 index 0000000000..b7901a2815 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/CMakeLists.txt @@ -0,0 +1,6 @@ +set(GROUPED_CONV3D_BWD_DATA_BILINEAR + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp) + +add_instance_library(device_grouped_conv3d_bwd_data_scale_instance ${GROUPED_CONV3D_BWD_DATA_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..af94c0ce9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances( + std::vector, + NDHWGC, + BF16, + BF16, + Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_bf16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..cc8995320a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances( + std::vector, + NDHWGC, + F16, + F16, + Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f16_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..5ed7962bbc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data_scale/xdl/device_grouped_conv3d_bwd_data_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo, +// g, k] +void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances( + std::vector, + NDHWGC, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_scale_f32_instances<3, + NDHWGK, + GKZYXC, + Tuple<>, + NDHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt new file mode 100644 index 0000000000..45d270d554 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/CMakeLists.txt @@ -0,0 +1,7 @@ +set(GROUPED_CONV3D_FWD_BILINEAR + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp + xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp) + +add_instance_library(device_grouped_conv3d_fwd_scale_instance ${GROUPED_CONV3D_FWD_BILINEAR}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..acff3e81b3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_instances( + std::vector, + NDHWGK, + BF16, + BF16, + ck::Tuple<>, + BF16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_bf16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp new file mode 100644 index 0000000000..dacbfe6783 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances( + std::vector, + NDHWGK, + F16, + F16, + ck::Tuple<>, + F16, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f16_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp new file mode 100644 index 0000000000..9e2c1131ae --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances( + std::vector, + NDHWGK, + F32, + F32, + ck::Tuple<>, + F32, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_f32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp new file mode 100644 index 0000000000..f9cbf1c44e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_scale/xdl/device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp @@ -0,0 +1,54 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_scale_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_int8_instances( + std::vector, + NDHWGK, + int8_t, + int8_t, + ck::Tuple<>, + int8_t, + PassThrough, + PassThrough, + Scale>>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_scale_int8_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From e626d5202ab826ee22b369d053ab9d42ab343cff Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Fri, 15 Mar 2024 11:50:03 -0500 Subject: [PATCH 28/36] Add instances for conv_scale with fp8 in/out (#1193) * Add fp8 conv instances and client example * Format * Add example * Update cmakelists * Add profiler mode * Format * Fix copyright headers --- client_example/16_convnd_fwd/CMakeLists.txt | 3 + .../16_convnd_fwd/conv3d_fwd_fp8.cpp | 46 ++++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + example/09_convnd_fwd/convnd_fwd_common.hpp | 91 ++++++++++++++++++- example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp | 81 +++++++++++++++++ ...ed_conv_fwd_bias_relu_add_wmma_example.inc | 7 +- .../device_grouped_conv_fwd_xdl_instance.hpp | 38 +++++++- .../gpu/grouped_convolution_forward.hpp | 23 ++++- .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 + ..._xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp | 53 +++++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 11 ++- 11 files changed, 349 insertions(+), 10 deletions(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index 5279e3dfcf..e2797415ef 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -7,6 +7,9 @@ endif() if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp) target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations) + + add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp) + target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) endif() if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp new file mode 100644 index 0000000000..2506e29e0e --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index f9903bfe03..a3f63350f4 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -6,6 +6,7 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) + add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index 109b8f9ee3..b0fd6a382a 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -27,6 +27,88 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + template (), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp new file mode 100644 index 0000000000..ef130148bc --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::f8_t; +using WeiDataType = ck::f8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeDataType = ck::f8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeDataType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc index ca8746bb97..3248c5fa4d 100644 --- a/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc +++ b/example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. template struct LayoutSetting @@ -279,8 +279,9 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[]) switch(conv_param.num_dim_spatial_) { // case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param); - case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); - // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); + case 2: + return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param); + // case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param); } return false; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 56b362eb9b..e6040e0d9e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -214,6 +214,42 @@ using device_grouped_conv_fwd_xdl_f16_comp_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_FP8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, F8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 1be5c324c6..7d3071c171 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -727,6 +727,21 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance PassThrough, PassThrough, F8>>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP32 @@ -1137,6 +1152,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 540ce3410b..998c1a51a9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -30,4 +30,9 @@ if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp) endif() +if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp new file mode 100644 index 0000000000..48ec4397bc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index d0b424cde6..7dff5bf5ce 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #include #include @@ -23,6 +23,7 @@ enum struct ConvDataType F16_F16_F16, // 1 BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 }; #define OP_NAME "grouped_conv_fwd" @@ -36,7 +37,8 @@ static void print_helper_msg() << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" << " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight bf16, Output bf16\n" - << " 3: Input int8, Weight int8, Output int8)\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -79,6 +81,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using F16 = ck::half_t; using BF16 = ck::bhalf_t; using INT8 = int8_t; + using F8 = ck::f8_t; // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -250,6 +253,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F8_F8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; From bdcd037428ac356e5b77271b7b6669c5c2d9548a Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 18 Mar 2024 09:48:29 -0700 Subject: [PATCH 29/36] Re-enable the performance tracking in CI. (#1203) * test CK with rocm6.1 RC2 * add docker credentials for pull * update the performance db name * use environment variable for db name * add rocm-llvm-dev package to ck docker * turn off verification for daily performance runs * do not stash ckProfiler on MI300 node * add processing of mixed gemms to qa, fix parsing of splitk gemm logs * fix the splitk gemm log file name * turn the timing on for splitk gemm performance --- Dockerfile | 19 ++++++----- Jenkinsfile | 47 +++++++++++++++------------- script/process_perf_data.py | 9 ++++-- script/run_full_performance_tests.sh | 26 +++++---------- 4 files changed, 52 insertions(+), 49 deletions(-) diff --git a/Dockerfile b/Dockerfile index 38f234943c..e3e791729e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,17 +16,17 @@ RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg -RUN if [ "$ROCMVERSION" != "6.0.1" ]; then \ +RUN if [ "$ROCMVERSION" != "6.1" ]; then \ sh -c "wget https://repo.radeon.com/amdgpu-install/6.0/ubuntu/focal/amdgpu-install_6.0.60000-1_all.deb --no-check-certificate" && \ apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated ./amdgpu-install_6.0.60000-1_all.deb && \ wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ - elif [ "$ROCMVERSION" = "6.0.1" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.0-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.0-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.0.1 rel-95 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=1704947; \ + elif [ "$ROCMVERSION" = "6.1" ] && [ "$compiler_version" = "rc2" ]; then \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.1-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.1-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.1 rel-48 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=1736298; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" @@ -41,6 +41,7 @@ chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} # Install dependencies +# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ cmake \ @@ -60,6 +61,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python3-dev \ python3-pip \ redis \ + rocm-llvm-dev \ sshpass \ stunnel \ software-properties-common \ @@ -73,6 +75,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +# Update the cmake to version 3.27.5 +RUN pip install --upgrade cmake==3.27.5 + #Install latest ccache RUN git clone https://github.com/ccache/ccache.git && \ cd ccache && mkdir build && cd build && cmake .. && make install @@ -82,8 +87,6 @@ RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releas RUN gunzip /usr/local/bin/ninja.gz RUN chmod a+x /usr/local/bin/ninja RUN git clone https://github.com/nico/ninjatracing.git -# Update the cmake to the latest version -RUN pip install --upgrade cmake==3.27.5 #Install latest cppcheck RUN git clone https://github.com/danmar/cppcheck.git && \ diff --git a/Jenkinsfile b/Jenkinsfile index abecb76408..e60bae2b65 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -38,7 +38,7 @@ def getDockerImageName(){ img = "${params.USE_CUSTOM_DOCKER}" } else{ - if (params.ROCMVERSION != "6.0.1"){ + if (params.ROCMVERSION != "6.1"){ if (params.COMPILER_VERSION == "") { img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}" } @@ -117,7 +117,9 @@ def getDockerImage(Map conf=[:]){ { echo "Pulling down image: ${image}" retimage = docker.image("${image}") - retimage.pull() + withDockerRegistry([ credentialsId: "docker_test_cred", url: "" ]) { + retimage.pull() + } } catch(Exception ex) { @@ -406,7 +408,7 @@ def runCKProfiler(Map conf=[:]){ dir("script"){ if (params.RUN_FULL_QA){ - sh "./run_full_performance_tests.sh 1 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" + sh "./run_full_performance_tests.sh 0 QA_${params.COMPILER_VERSION} ${env.BRANCH_NAME} ${NODE_NAME}" archiveArtifacts "perf_gemm.log" archiveArtifacts "perf_resnet50_N256.log" archiveArtifacts "perf_resnet50_N4.log" @@ -416,9 +418,9 @@ def runCKProfiler(Map conf=[:]){ archiveArtifacts "perf_conv_bwd_data.log" archiveArtifacts "perf_gemm_bilinear.log" archiveArtifacts "perf_reduction.log" - archiveArtifacts "perf_splitK_gemm_verify.log" archiveArtifacts "perf_splitK_gemm.log" archiveArtifacts "perf_onnx_gemm.log" + archiveArtifacts "perf_mixed_gemm.log" // stash perf files to master stash name: "perf_gemm.log" stash name: "perf_resnet50_N256.log" @@ -431,6 +433,7 @@ def runCKProfiler(Map conf=[:]){ stash name: "perf_reduction.log" stash name: "perf_splitK_gemm.log" stash name: "perf_onnx_gemm.log" + stash name: "perf_mixed_gemm.log" //we will process results on the master node } else{ @@ -493,9 +496,6 @@ def Build_CK(Map conf=[:]){ def variant = env.STAGE_NAME def retimage - def navi_node = 0 - def mi300_node = 0 - gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { try { (retimage, image) = getDockerImage(conf) @@ -508,14 +508,6 @@ def Build_CK(Map conf=[:]){ else{ echo "GPU is OK" } - if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ - navi_node = 1 - echo "This is a Navi node" - } - if ( runShell('grep -n "gfx942" rocminfo.log') ){ - mi300_node = 1 - echo "This is MI300 node" - } } } } @@ -526,15 +518,27 @@ def Build_CK(Map conf=[:]){ withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') { timeout(time: 24, unit: 'HOURS') { + //check whether running on Navi or MI300 node + def navi_node = 0 + def mi300_node = 0 + sh 'rocminfo | tee rocminfo.log' + if ( runShell('grep -n "gfx1030" rocminfo.log') || runShell('grep -n "gfx1101" rocminfo.log') ){ + navi_node = 1 + echo "This is a Navi node" + } + if ( runShell('grep -n "gfx942" rocminfo.log') ){ + mi300_node = 1 + echo "This is MI300 node" + } cmake_build(conf) dir("build"){ //run tests and examples sh 'make -j check' - if (navi_node == 0 ){ + if (params.RUN_PERFORMANCE_TESTS && navi_node == 0 && mi300_node == 0 ){ //we only need the ckProfiler to run the performance tests, so we pack and stash it - //do not stash profiler on Navi nodes + //do not stash profiler on Navi or MI300 nodes sh 'tar -zcvf ckProfiler.tar.gz bin/ckProfiler' - stash "ckProfiler.tar.gz" + stash name: "ckProfiler.tar.gz" } if (params.RUN_FULL_QA && mi300_node == 0 ){ // build deb packages for all MI100/200/300 targets and prepare to export @@ -542,7 +546,7 @@ def Build_CK(Map conf=[:]){ archiveArtifacts artifacts: 'composablekernel-ckprofiler_*.deb' archiveArtifacts artifacts: 'composablekernel-tests_*.deb' sh 'mv composablekernel-ckprofiler_*.deb ckprofiler_0.2.0_amd64.deb' - stash "ckprofiler_0.2.0_amd64.deb" + stash name: "ckprofiler_0.2.0_amd64.deb" } } if (params.hipTensor_test && navi_node == 0 ){ @@ -629,6 +633,7 @@ def process_results(Map conf=[:]){ unstash "perf_reduction.log" unstash "perf_splitK_gemm.log" unstash "perf_onnx_gemm.log" + unstash "perf_mixed_gemm.log" sh "./process_qa_data.sh" unstash "ckprofiler_0.2.0_amd64.deb" sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" @@ -716,8 +721,8 @@ pipeline { description: "Run the cppcheck static analysis (default: OFF)") booleanParam( name: "RUN_PERFORMANCE_TESTS", - defaultValue: false, - description: "Run the performance tests (default: OFF)") + defaultValue: true, + description: "Run the performance tests (default: ON)") booleanParam( name: "RUN_CODEGEN_TESTS", defaultValue: true, diff --git a/script/process_perf_data.py b/script/process_perf_data.py index d7e40569fd..2c46da8fd2 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -133,11 +133,16 @@ def parse_logfile(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[4]) - elif 'onnx_gemm' in logfile or 'splitK_gemm' in logfile or 'mixed_gemm' in logfile: + elif 'onnx_gemm' in logfile or 'mixed_gemm' in logfile: for line in open(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[33]) + elif 'splitK_gemm' in logfile: + for line in open(logfile): + if 'Best Perf' in line: + lst=line.split() + res.append(lst[36]) return res @@ -231,7 +236,7 @@ def main(): sql_hostname = '127.0.0.1' sql_username = os.environ["dbuser"] sql_password = os.environ["dbpassword"] - sql_main_database = 'miopen_perf' + sql_main_database = os.environ["ck_perf_db"] sql_port = 3306 ssh_host = os.environ["dbsship"] ssh_user = os.environ["dbsshuser"] diff --git a/script/run_full_performance_tests.sh b/script/run_full_performance_tests.sh index 90678389fa..01ac1b0a39 100755 --- a/script/run_full_performance_tests.sh +++ b/script/run_full_performance_tests.sh @@ -121,26 +121,16 @@ print_log_header $reduction_log $env_type $branch $host_name ./profile_reduce_no_index.sh $verify 2 10 --half 2>&1 | tee -a $reduction_log #run splitK_gemm tests, first correctness verification, then performance -export splitK_gemm_ver_log="perf_splitK_gemm_verify.log" -print_log_header $splitK_gemm_ver_log $env_type $branch $host_name -./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 0 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 0 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 1 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 2 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log -./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 0 4 2>&1 | tee -a $splitK_gemm_ver_log export splitK_gemm_log="perf_splitK_gemm.log" print_log_header $splitK_gemm_log $env_type $branch $host_name -./profile_splitK_gemm.sh gemm_splitk 0 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 0 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 0 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 1 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 2 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log -./profile_splitK_gemm.sh gemm_splitk 1 3 0 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 0 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 1 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 2 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 0 3 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 0 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 1 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 2 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log +./profile_splitK_gemm.sh gemm_splitk 1 3 $verify 1 0 1 4 2>&1 | tee -a $splitK_gemm_log #run ONNX gemm tests export onnx_log="perf_onnx_gemm.log" From 9e011bcd6e7735fbcd9045bbd7f2fb98df1446a0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 18 Mar 2024 10:16:45 -0700 Subject: [PATCH 30/36] update the changelog for ROCm6.1 release (#1205) * update the changelog for ROCm6.1 release * modifty the order of items in changelog, capitalize GEMMs --- CHANGELOG.md | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e3feed2df..fb2ba1975f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,20 +2,27 @@ Full documentation for Composable Kernel is not yet available. -## (Unreleased) CK - -### Fixes -None - -### Optimizations -None +## CK for ROCm 6.1.0 ### Additions -* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126, #1139) +* Added generic instances for GEMM XDL operations (#1161) +* Added gamma and beta parameters for the layernorm and groupnorm bwd operations (#1133) +* Introduced wrapper sublibrary (limited functionality). (#1071, #1098, #1108, #1126) +* Added an option to vary the number of warm-up cycles and iterations for ckProfiler (#1124) + +### Optimizations +* New performance optimizations for GEMM operations on MI200 and MI300 architectures (#1135) + +### Fixes +* Reduced the build time for most GPU architectures (#1084) +* Fixed some conversion issues for fp8 data type (#1099) ### Changes None +### Known issues +None + ## CK for ROCm 6.0.0 ### Fixes @@ -32,7 +39,7 @@ None * Grouped convolution support for small K and C (#822 #879 #897) * Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804) * Support for bf16/f32/f16 and NHWGC (2D and 3D) grouped convolution backward data (#757 #799) -* Support for Batched Gemm DL (#732) +* Support for Batched GEMM DL (#732) ### Changes * Changed the grouped convolution API to maintain consistency with other convolution kernels (#817) @@ -48,7 +55,7 @@ None ### Additions * New CMake flags: - * "DL_KERNELS"-* Must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances + * "DL_KERNELS"-* Must be set to "ON" in order to build the GEMM DL and batched_gemm_multi_d_dl instances * "DTYPES" -- Can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build an instance of the specified data types * "INSTANCES_ONLY" -- Only builds CK library and instances without tests, examples, or profiler * New feature: if GPU_TARGETS is not set in the CMake command line, CK will be built for all targets supported by the compiler From f52109531b539a9dc8f7f744a104e10558288946 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 19 Mar 2024 08:38:52 -0700 Subject: [PATCH 31/36] Fix a couple of docker issues. (#1206) * do not install sccache by default, only install rocm-llvm-dev for rocm6.1 * add sccache flag to docker build options --- Dockerfile | 18 ++++++++++++------ Jenkinsfile | 9 +++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index e3e791729e..cc8b1eadf2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,7 @@ ARG DEBIAN_FRONTEND=noninteractive ARG ROCMVERSION=6.0 ARG compiler_version="" ARG compiler_commit="" +ARG CK_SCCACHE="" RUN set -xe @@ -32,16 +33,18 @@ RUN if [ "$ROCMVERSION" != "6.1" ]; then \ RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" RUN amdgpu-install -y --usecase=rocm --no-dkms -## Sccache binary built from source for ROCm +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin -RUN mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ -curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ -chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi # Install dependencies -# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \ build-essential \ cmake \ @@ -61,7 +64,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- python3-dev \ python3-pip \ redis \ - rocm-llvm-dev \ sshpass \ stunnel \ software-properties-common \ @@ -75,6 +77,10 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- apt-get clean && \ rm -rf /var/lib/apt/lists/* +# hipTensor requires rocm-llvm-dev for rocm versions > 6.0.1 +RUN if [ "$ROCMVERSION" = "6.1" ]; then \ + sh -c "apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated rocm-llvm-dev"; \ + fi # Update the cmake to version 3.27.5 RUN pip install --upgrade cmake==3.27.5 diff --git a/Jenkinsfile b/Jenkinsfile index e60bae2b65..ec3cbd0e27 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -104,7 +104,7 @@ def getDockerImage(Map conf=[:]){ env.DOCKER_BUILDKIT=1 def prefixpath = conf.get("prefixpath", "/opt/rocm") def no_cache = conf.get("no_cache", false) - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if(no_cache) { dockerArgs = dockerArgs + " --no-cache " @@ -134,7 +134,7 @@ def buildDocker(install_prefix){ checkout scm def image_name = getDockerImageName() echo "Building Docker for ${image_name}" - def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg BUILDKIT_INLINE_CACHE=1 --build-arg PREFIX=${install_prefix} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " echo "Build Args: ${dockerArgs}" try{ @@ -311,7 +311,7 @@ def buildHipClangJob(Map conf=[:]){ if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } - def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " + def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " } @@ -367,9 +367,6 @@ def runCKProfiler(Map conf=[:]){ dockerOpts = dockerOpts + " --env HSA_XNACK=1 " } def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' " - if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline-open" || params.COMPILER_COMMIT != ""){ - dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' " - } def variant = env.STAGE_NAME def retimage From 9e5042691539ba6731158c5c7b83fff4a25f7715 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 09:28:03 -0600 Subject: [PATCH 32/36] Bump rocm-docs-core from 0.36.0 to 0.37.0 in /docs/sphinx (#1208) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.36.0 to 0.37.0. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.36.0...v0.37.0) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index b3c8267736..ae92cc6c10 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.36.0 +rocm-docs-core==0.37.0 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index ba1d7da441..43853dd3fa 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -96,9 +96,7 @@ pygments==2.15.0 # pydata-sphinx-theme # sphinx pyjwt[crypto]==2.6.0 - # via - # pygithub - # pyjwt + # via pygithub pynacl==1.5.0 # via pygithub pytz==2023.3.post1 @@ -113,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.36.0 +rocm-docs-core==0.37.0 # via -r requirements.in six==1.16.0 # via From fd0d093e78c18197a4f1b7dafdbc1e2438d28317 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Thu, 21 Mar 2024 13:57:34 -0500 Subject: [PATCH 33/36] Add instances for conv_scale with bf8 in / fp8 out (#1200) * Add bf8 conv fwd instances * Add example * Add profiler mode * Add client example * Fix copyright headers * Format --- client_example/16_convnd_fwd/CMakeLists.txt | 5 ++ .../16_convnd_fwd/conv3d_fwd_bf8.cpp | 46 +++++++++++ example/09_convnd_fwd/CMakeLists.txt | 1 + example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp | 81 +++++++++++++++++++ .../device_grouped_conv_fwd_xdl_instance.hpp | 40 +++++++++ .../gpu/grouped_convolution_forward.hpp | 24 ++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 ++ ..._xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp | 53 ++++++++++++ profiler/src/profile_grouped_conv_fwd.cpp | 9 ++- 9 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt index e2797415ef..e034c468d5 100644 --- a/client_example/16_convnd_fwd/CMakeLists.txt +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -12,6 +12,11 @@ if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES) target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations) endif() +if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES) + add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp) + target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations) +endif() + if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES) add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_conv_operations) diff --git a/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp b/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp new file mode 100644 index 0000000000..983e0d083c --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using OutDataType = ck::f8_t; + +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index a3f63350f4..195f1857ed 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -7,6 +7,7 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp8 convnd_fwd_xdl_fp8.cpp) + add_example_executable(example_convnd_fwd_xdl_bf8 convnd_fwd_xdl_bf8.cpp) # FIXME: re-enable this exampe as test when SWDEV-335738 is fixed add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp new file mode 100644 index 0000000000..0fc9e7b5dd --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_bf8.cpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +using InDataType = ck::bf8_t; +using WeiDataType = ck::bf8_t; +using AccDataType = float; +using CShuffleDataType = ck::f8_t; +using OutDataType = ck::f8_t; +using ComputeType = ck::bf8_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, + WeiLayout, + ck::Tuple<>, + OutLayout, + InDataType, + WeiDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + ConvSpec, // ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 32, // KPerBlock + 8, // AK1 + 8, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, + 1, + S<1, 32, 1, 8>, + 8, + ComputeType>; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index e6040e0d9e..0f845ca1ed 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -17,6 +17,10 @@ namespace instance { using F8 = ck::f8_t; #endif +#ifdef CK_ENABLE_BF8 +using BF8 = ck::bf8_t; +#endif + using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; @@ -250,6 +254,42 @@ using device_grouped_conv_fwd_xdl_f8_instances = std::tuple< // clang-format on >; +template +using device_grouped_conv_fwd_xdl_bf8_instances = std::tuple< +// clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#ifdef CK_ENABLE_BF8 + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, BF8> +#endif + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 7d3071c171..b9712542a8 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -744,6 +744,23 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f8_instances( F8>>>& instances); #endif +#ifdef CK_ENABLE_BF8 +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances); +#endif + #ifdef CK_ENABLE_FP32 void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( std::vector && is_same_v && + is_same_v && is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances(op_ptrs); + } +#endif #ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 998c1a51a9..3825b92af4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -35,4 +35,9 @@ if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES) xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_fp8_instance.cpp) endif() +if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES) + list(APPEND GROUPED_CONV3D_FWD + xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp) +endif() + add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp new file mode 100644 index 0000000000..9f1ceae808 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instance.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf8_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_bf8_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index 7dff5bf5ce..1f72733729 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -24,6 +24,7 @@ enum struct ConvDataType BF16_BF16_BF16, // 2 INT8_INT8_INT8, // 3 F8_F8_F8, // 4 + BF8_BF8_F8, // 5 }; #define OP_NAME "grouped_conv_fwd" @@ -38,7 +39,8 @@ static void print_helper_msg() << " 1: Input fp16, Weight fp16, Output fp16\n" << " 2: Input bf16, Weight bf16, Output bf16\n" << " 3: Input int8, Weight int8, Output int8\n" - << " 4: Input fp8, Weight fp8, Output fp8)\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n" << "arg4: verification (0: no, 1: yes)\n" @@ -82,6 +84,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using BF16 = ck::bhalf_t; using INT8 = int8_t; using F8 = ck::f8_t; + using BF8 = ck::bf8_t; // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -257,6 +260,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F8{}, F8{}, F8{}); } + else if(data_type == ConvDataType::BF8_BF8_F8) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, BF8{}, F8{}); + } } std::cout << "this data_type & layout is not implemented" << std::endl; From 9c052804a75491865bab0fad49d059b6e4e98cdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 22 Mar 2024 10:40:43 +0100 Subject: [PATCH 34/36] Add elementwise with dynamic vector dim (#1198) * Add elementwise with dynamic vector dim * Reduce number of instaces * Fixes * Fixes --- .../elementwise_permute_4D_fp16.cpp | 25 +- .../elementwise_permute_4D_fp16_col.cpp | 54 +- .../elementwise_permute_4D_fp16_row.cpp | 53 +- .../elementwise_permute_4D_fp32_col.cpp | 54 +- .../elementwise_permute_4D_fp32_row.cpp | 53 +- ...hread_group_tensor_slice_transfer_v4r2.hpp | 193 +++++ ...e_elementwise_dynamic_vector_dims_impl.hpp | 422 +++++++++ ...idwise_elementwise_dynamic_vector_dims.hpp | 169 ++++ .../threadwise_tensor_slice_transfer_v3r2.hpp | 804 ++++++++++++++++++ .../gpu/permute_scale.hpp | 116 +-- .../device_permute_scale_instances.hpp | 179 +++- .../gpu/permute_scale/CMakeLists.txt | 18 +- ...evice_permute_scale_1d_fp16_instances.cpp} | 15 +- ...device_permute_scale_1d_fp32_instances.cpp | 24 + ...evice_permute_scale_2d_fp16_instances.cpp} | 15 +- ...device_permute_scale_2d_fp32_instances.cpp | 24 + ...evice_permute_scale_3d_fp16_instances.cpp} | 15 +- ...device_permute_scale_3d_fp32_instances.cpp | 24 + ...evice_permute_scale_4d_fp16_instances.cpp} | 15 +- ...device_permute_scale_4d_fp32_instances.cpp | 24 + ...evice_permute_scale_5d_fp16_instances.cpp} | 15 +- ...device_permute_scale_5d_fp32_instances.cpp | 24 + ...evice_permute_scale_6d_fp16_instances.cpp} | 15 +- ...device_permute_scale_6d_fp32_instances.cpp | 24 + .../profiler/profile_permute_scale_impl.hpp | 46 +- profiler/src/profile_permute_scale.cpp | 29 +- script/profile_permute_scale.sh | 43 + test/permute_scale/test_permute_scale.cpp | 24 +- 28 files changed, 2157 insertions(+), 359 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp create mode 100644 include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_1d_instances.cpp => device_permute_scale_1d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_2d_instances.cpp => device_permute_scale_2d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_3d_instances.cpp => device_permute_scale_3d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_4d_instances.cpp => device_permute_scale_4d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_5d_instances.cpp => device_permute_scale_5d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp rename library/src/tensor_operation_instance/gpu/permute_scale/{device_permute_scale_6d_instances.cpp => device_permute_scale_6d_fp16_instances.cpp} (56%) create mode 100644 library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp create mode 100755 script/profile_permute_scale.sh diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp index 8e9bc64ab6..1b28a901cb 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,15 +20,20 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // Elementwise op - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + PassThrough, // Elementwise + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq template void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp index 9d5fdc0cc7..f832601f07 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_col.cpp @@ -7,7 +7,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -21,26 +21,23 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<1>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { std::size_t N = A_nchw.mDesc.GetLengths()[0]; std::size_t C = A_nchw.mDesc.GetLengths()[1]; @@ -51,11 +48,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, for(std::size_t c = 0; c < C; ++c) for(std::size_t n = 0; n < N; ++n) { - ADataType tmp_val; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor_b(tmp_val, a_val); - functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], - scale * tmp_val); + functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); } } @@ -104,14 +98,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -143,7 +131,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp index 7d215cef24..bae85f53c1 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp16_row.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,36 +20,31 @@ using F32 = float; using ADataType = F16; using BDataType = F16; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) { - ADataType tmp_val; auto a_val = A_nchw(n, c, h, w); - functor_b(tmp_val, a_val); - functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + functor(B_nhwc(n, h, w, c), a_val); } } @@ -86,14 +81,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -125,7 +114,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp index 69e411c59a..fe7acd3010 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_col.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,26 +20,23 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 1, // MPerThread - ck::Sequence<1>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<1>, // InScalarPerVectorSeq + ck::Sequence<1>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { std::size_t N = A_nchw.mDesc.GetLengths()[0]; std::size_t C = A_nchw.mDesc.GetLengths()[1]; @@ -50,11 +47,8 @@ void host_elementwise4D(HostTensorB& B_nhwc, for(std::size_t c = 0; c < C; ++c) for(std::size_t n = 0; n < N; ++n) { - ADataType tmp_val; auto a_val = A_nchw.mData[(n) + (c * N) + (h * C * N) + (w * H * C * N)]; - functor_b(tmp_val, a_val); - functor_a(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], - scale * tmp_val); + functor(B_nhwc.mData[(n) + (c * W * H * N) + (h * N) + (w * H * N)], a_val); } } @@ -104,14 +98,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -143,7 +131,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp index 69f40fe165..aebdb37d9b 100644 --- a/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp +++ b/example/44_elementwise_permute/elementwise_permute_4D_fp32_row.cpp @@ -6,7 +6,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/utility/algorithm.hpp" #include "ck/library/utility/check_err.hpp" @@ -20,36 +20,31 @@ using F32 = float; using ADataType = F32; using BDataType = F32; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; -using DeviceElementwisePermuteInstance = - ck::tensor_operation::device::DeviceElementwiseImpl, // InDataTypeTuple - ck::Tuple, // OutDataTypeTuple - PassThrough, // ElementwiseOp - UnaryOp, // UnaryOp - Scale, // Scalar - 4, // NumDim - 8, // MPerThread - ck::Sequence<8>, // InScalarPerVectorSeq - ck::Sequence<1>>; // OutScalarPerVectorSeq +using UnaryOp = ck::tensor_operation::element_wise::Scale; +using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl< + ck::Tuple, // InDataTypeTuple + ck::Tuple, // OutDataTypeTuple + UnaryOp, // UnaryOp + 4, // NumDim + 256, // BlockSize + 128, // M0PerBlock + 128, // M1PerBlock + 8, // M0PerThread + 8, // M1PerThread + ck::Sequence<1, 0>, // ThreadClusterArrangeOrder + ck::Sequence<8>, // InScalarPerVectorSeq + ck::Sequence<8>>; // OutScalarPerVectorSeq -template -void host_elementwise4D(HostTensorB& B_nhwc, - const HostTensorA& A_nchw, - FunctorA functor_a, - FunctorB functor_b, - float scale) +template +void host_elementwise4D(HostTensorB& B_nhwc, const HostTensorA& A_nchw, Functor functor) { for(std::size_t n = 0; n < A_nchw.mDesc.GetLengths()[0]; ++n) for(std::size_t c = 0; c < A_nchw.mDesc.GetLengths()[1]; ++c) for(std::size_t h = 0; h < A_nchw.mDesc.GetLengths()[2]; ++h) for(std::size_t w = 0; w < A_nchw.mDesc.GetLengths()[3]; ++w) { - ADataType tmp_val; auto a_val = A_nchw(n, c, h, w); - functor_b(tmp_val, a_val); - functor_a(B_nhwc(n, h, w, c), scale * tmp_val); + functor(B_nhwc(n, h, w, c), a_val); } } @@ -86,14 +81,8 @@ int main() ck::ranges::copy(nchw, ab_lengths.begin()); auto broadcastPermute = DeviceElementwisePermuteInstance{}; - auto argument = broadcastPermute.MakeArgumentPointer(ab_lengths, - {a_strides}, - {b_strides}, - input, - output, - PassThrough{}, - UnaryOp{}, - Scale{scale}); + auto argument = broadcastPermute.MakeArgumentPointer( + ab_lengths, {a_strides}, {b_strides}, input, output, UnaryOp{scale}); if(!broadcastPermute.IsSupportedArgument(argument.get())) { @@ -125,7 +114,7 @@ int main() { b_device_buf.FromDevice(b.mData.data()); Tensor host_b(nhwc); - host_elementwise4D(host_b, a, PassThrough{}, UnaryOp{}, scale); + host_elementwise4D(host_b, a, UnaryOp{scale}); pass &= ck::utils::check_err(b.mData, host_b.mData, "Error: Incorrect results b", 1e-3, 1e-3); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp new file mode 100644 index 0000000000..aa1f7c5735 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r2 +{ + static constexpr index_t nDim = + remove_reference_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + + { + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffer& src_bufs, + const DstDescs& dst_descs, + DstBuffer& dst_bufs, + Number thread_scratch_id) + { + RunRead(src_descs, src_bufs, thread_scratch_id); + RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp new file mode 100644 index 0000000000..4dba95e5d3 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseImpl + : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static index_t GetLowestStrideDim(const std::array& strides) + { + index_t most_continous_dim = NumDim - 1; + index_t most_continous_dim_stride = strides[most_continous_dim]; + for(index_t dim = 0; dim < NumDim; dim++) + { + if(strides[dim] < most_continous_dim_stride) + { + most_continous_dim_stride = strides[dim]; + most_continous_dim = dim; + } + } + return most_continous_dim; + } + + template + static auto PadInputOutputDescriptor(const InOutDescriptor& desc) + { + const auto M0 = desc.GetLength(I0); + const auto M1 = desc.GetLength(I1); + const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0; + const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return padded_desc; + } + + static auto GenerateBatchDimsLenghtsTuple(const std::array& lengths, + const index_t M0_dim, + const index_t M1_dim) + { + // Generate batch dims, they will be merged to M0 + // Add one more dim than needed in case that M0 is equal to M1 + // If M0 is equal to M1, then will be one more batch dim + std::array batch_dims; + index_t batch_dim = 0; + for(index_t i = 0; i < NumDim; i++) + { + if(i != M0_dim && i != M1_dim) + { + batch_dims[batch_dim] = lengths[i]; + batch_dim++; + } + } + // Add dummy dim if M0_dim is not equal to M1_dim + if(M0_dim != M1_dim && NumDim >= 2) + batch_dims[NumDim - 2] = 1; + return generate_tuple([&](auto I) { return batch_dims[I]; }, Number{}); + } + + static auto MakeDescriptor(const std::array& lengths, + const std::array& in_strides, + const std::array& out_strides, + const std::array& desc_strides) + { + const auto M0_dim = GetLowestStrideDim(out_strides); + const auto M1_dim = GetLowestStrideDim(in_strides); + + // If M0_dim is equal to M1_dim, then make M0_dim dummy + const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim]; + const auto M1 = lengths[M1_dim]; + const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim]; + const auto M1_stride = desc_strides[M1_dim]; + + const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim); + const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim); + + const auto desc = make_naive_tensor_descriptor( + concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)), + concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride))); + // Merged batch dims with M0 + const auto transforms = + make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))), + make_pass_through_transform(M1)); + using BatchElemsSequence = + typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type; + const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence{}); + const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{}); + // desc: (merged_dims + M0, M1) + auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); + return PadInputOutputDescriptor(merged_desc); + } + + template + static auto GenerateInOutGridDescTuple() + { + std::array ones; + for(index_t d = 0; d < NumDim; d++) + { + ones[d] = 1; + } + + return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); }, + Number{}); + }; + + using InGridDescTuple = decltype(GenerateInOutGridDescTuple()); + using OutGridDescTuple = decltype(GenerateInOutGridDescTuple()); + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseOp = GridwiseElementwise; + + using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto in_grid_desc_tuple = generate_tuple( + [&](auto src_i) { + // Use Strides from first tensor to assert that M0 dim and + // M1 dim are the same for each tensor. + return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.inStridesArray_[src_i]); + }, + Number{}); + + auto out_grid_desc_tuple = generate_tuple( + [&](auto dst_i) { + return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.outStridesArray_[dst_i]); + }, + Number{}); + + const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + + const auto block_2_tile_map = Block2TileMap(M0, M1); + const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1); + + const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) == + GetLowestStrideDim(arg.outStridesArray_[I0]); + + const auto kernel = in_out_same_vector_dim + ? kernel_elementwise + : kernel_elementwise; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + in_grid_desc_tuple, + out_grid_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + block_2_tile_map, + arg.elementwise_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]); + const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]); + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t M_dim) { + if(scalarPerVector == 1) + { + return true; + } + if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0) + { + return true; + } + return false; + }; + + bool is_valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 && + M1PerThread % InScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 && + M1PerThread % OutScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim); + }); + + return is_valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<"; + str << NumDim << ", "; + str << BlockSize << ", "; + str << M0PerBlock << ", "; + str << M1PerBlock << ", "; + str << M0PerThread << ", "; + str << M1PerThread << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp new file mode 100644 index 0000000000..2a906a1432 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) +{ + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + block_2_tile_map, + elementwise_op); +} + +template +struct GridwiseElementwise +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple, + const OutGridDescTuple& out_grid_desc_tuple, + const InDataTypePointerTuple& p_in_global_tuple, + const OutDataTypePointerTuple& p_out_global_tuple, + const Block2TileMap& block_2_tile_map, + const ElementwiseOperation& elementwise_op) + { + + constexpr auto src_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return DataType{}; + }, + Number{}); + + constexpr auto dst_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return DataType{}; + }, + Number{}); + + const auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); + const index_t m1_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); + const auto thread_grid_offset = + make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + + using ThisThreadBlock = ThisThreadBlock; + // If src and dst have same vector dim, then: + // M0 dim - for src and dst vector load/store + // else: + // M0 dim - for dst vector load + // M1 dim - for src vector store + using SrcDimAccessOrder = Sequence<0, 1>; + using DstDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + using SrcVectorDim = Number<1>; + using DstVectorDim = std::conditional_t, Number<0>>; + + using ThreadClusterLengths = + Sequence{}, Number{}>; + + auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2< + ThisThreadBlock, + ElementwiseOperation, + uniform_sequence_gen_t(InMemoryDataOperationEnum::Set)>, + Sequence, + ThreadClusterLengths, + ThreadClusterArrangeOrder, + decltype(src_datas), + decltype(dst_datas), + InGridDescTuple, + OutGridDescTuple, + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim{}, + DstVectorDim{}, + InScalarPerVectorSeq, + OutScalarPerVectorSeq, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t>{in_grid_desc_tuple, + thread_grid_offset, + out_grid_desc_tuple, + thread_grid_offset, + elementwise_op}; + global_to_global_transfer.Run( + in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp new file mode 100644 index 0000000000..f0d793456d --- /dev/null +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + src_coords_(src_i) = + make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + dst_coords_(dst_i) = + make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]); + }); + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access_tuple = generate_tuple( + [&](auto src_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return SliceLengths{} / src_scalar_per_access_tuple.At(src_i); + static_assert( + SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector"); + }, + Number{}); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i), + src_dim_access_order); + }, + Number{}); + + // make forward steps + const auto src_forward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto src_backward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -src_scalar_per_access_tuple.At(src_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_ford>{}( + [&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths_tuple[j] + + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_src_access_idx[i] + : ordered_src_access_lengths_tuple.At(src_i)[i] - + 1 - ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access_tuple.At(src_i); + }(); + + constexpr auto src_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_descs.At(src_i), src_coords_.At(src_i)); + + using src_vector_type = vector_type_maker_t, + SrcsScalarPerVector::At(src_i)>; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = + src_vector_type{src_bufs.At(src_i).template Get( + src_coords_.At(src_i).GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .At(src_i) + .template SetAsType( + src_data_idx_seq, + src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < + ordered_src_access_lengths_tuple.At(src_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == + ordered_src_access_lengths_tuple.At(src_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + } + }); + }); + }); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + // move src coordinate back to slice origin (or not) + if constexpr(SrcsResetCoordinateAfterRun::At(src_i)) + { + const auto src_reset_step = make_tensor_coordinate_step( + src_descs.At(src_i), GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step); + } + }); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { + // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + // (it requires to add Elementwise support in transpose_vectors) + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + [&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); + + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access_tuple = generate_tuple( + [&](auto dst_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); }, + Number{}); + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { + return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i), + dst_dim_access_order); + }, + Number{}); + + // make forward steps + const auto dst_forward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -dst_scalar_per_access_tuple.At(dst_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_ford>{}( + [&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths_tuple.At(dst_i)[i] - + 1 - ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access_tuple.At(dst_i); + }(); + + constexpr auto dst_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + dst_descs.At(dst_i), dst_coords_.At(dst_i)); + + using dst_vector_type = vector_type_maker_t, + DstsScalarPerVector::At(dst_i)>; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_tuple_.At(dst_i).template GetAsType( + dst_data_idx_seq)}; + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(dst_i.value)); + + // copy data from dst_vector_container to dst_buf + dst_bufs.At(dst_i).template Update( + dst_coords_.At(dst_i).GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < + ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == + ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + } + }); + }); + }); + + // move dst coordinate back to slice origin (or not) + static_for<0, nDst, 1>{}([&](auto dst_i) { + if constexpr(DstsResetCoordinateAfterRun::At(dst_i)) + { + const auto dst_reset_step = make_tensor_coordinate_step( + dst_descs.At(dst_i), GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step); + } + }); + } + + template + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + template + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access.At(dst_i); + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + const Index& src_slice_origin_step_idx) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcsResetCoordinateAfterRun::At(src_i) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step); + }); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + const Index& dst_slice_origin_step_idx) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstsResetCoordinateAfterRun::At(dst_i) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step); + }); + } + + template + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(src_access_lengths), + Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + template + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(dst_access_lengths), + Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto MakeSrcThreadScratchTuple() + { + return generate_tuple( + [&](auto src_i) { + constexpr auto src_thread_scratch_desc = + decltype(GetSrcThreadScratchDescriptor()){}; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer, + SrcsScalarPerVector::At(src_i), + decltype(src_thread_scratch_desc), + true>; + return SrcThreadScratch{}; + }, + Number{}); + } + + __device__ static constexpr auto MakeDstThreadScratchTuple() + { + return generate_tuple( + [&](auto dst_i) { + constexpr auto dst_thread_scratch_desc = + decltype(GetDstThreadScratchDescriptor()){}; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer, + DstsScalarPerVector::At(dst_i), + decltype(dst_thread_scratch_desc), + true>; + return DstThreadScratch{}; + }, + Number{}); + } + + private: + using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple()); + using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple()); + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratchTuple dst_thread_scratch_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp index 4b3f40e214..4f5d022f9c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale.hpp @@ -7,7 +7,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" @@ -19,125 +19,67 @@ namespace instance { #ifdef CK_ENABLE_FP16 void add_device_permute_scale_1d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 1>>>&); + std::vector, ck::Tuple, element_wise::Scale, 1>>>&); void add_device_permute_scale_2d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 2>>>&); + std::vector, ck::Tuple, element_wise::Scale, 2>>>&); void add_device_permute_scale_3d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 3>>>&); + std::vector, ck::Tuple, element_wise::Scale, 3>>>&); void add_device_permute_scale_4d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 4>>>&); + std::vector, ck::Tuple, element_wise::Scale, 4>>>&); void add_device_permute_scale_5d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 5>>>&); + std::vector, ck::Tuple, element_wise::Scale, 5>>>&); void add_device_permute_scale_6d_f16_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 6>>>&); + std::vector, ck::Tuple, element_wise::Scale, 6>>>&); #endif #ifdef CK_ENABLE_FP32 void add_device_permute_scale_1d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 1>>>&); + std::vector, ck::Tuple, element_wise::Scale, 1>>>&); void add_device_permute_scale_2d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 2>>>&); + std::vector, ck::Tuple, element_wise::Scale, 2>>>&); void add_device_permute_scale_3d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 3>>>&); + std::vector, ck::Tuple, element_wise::Scale, 3>>>&); void add_device_permute_scale_4d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 4>>>&); + std::vector, ck::Tuple, element_wise::Scale, 4>>>&); void add_device_permute_scale_5d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 5>>>&); + std::vector, ck::Tuple, element_wise::Scale, 5>>>&); void add_device_permute_scale_6d_f32_instances( - std::vector, - ck::Tuple, - PassThrough, - element_wise::UnarySquare, - Scale, - 6>>>&); + std::vector, ck::Tuple, element_wise::Scale, 6>>>&); #endif template struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceElementwise> + ck::tensor_operation::device:: + DeviceElementwise> { - using DeviceOp = DeviceElementwise; + using DeviceOp = + DeviceElementwise; static auto GetInstances() { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp index a672ab22df..8a22005413 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp @@ -2,7 +2,7 @@ // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/utility/data_type.hpp" namespace ck { @@ -13,26 +13,175 @@ namespace instance { using F16 = ck::half_t; using F32 = float; -using Pass = ck::tensor_operation::element_wise::PassThrough; -using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; -using Scale = ck::tensor_operation::element_wise::Scale; - // clang-format off -template +template using device_permute_scale_f16_instances = std::tuple < - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + +#if 0 + // Disabled instances to improve compilation time + // They listed here to show other possible combinations of parameters + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, +#endif + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> + >; -template +template using device_permute_scale_f32_instances = std::tuple< - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 1, ck::Sequence<1>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 8, ck::Sequence<8>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 4, ck::Sequence<4>, ck::Sequence<1>>, - DeviceElementwiseImpl, ck::Tuple, Pass, UnaryOp, Scale, NDims, 2, ck::Sequence<2>, ck::Sequence<1>> + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 256, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 128, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 32, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 64, 8, 8, ck::Sequence<1, 0>, ck::Sequence<8>, ck::Sequence<8>>, + +#if 0 + // Disabled instances to improve compilation time + // They listed here to show other possible combinations of parameters + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 512, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 512, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 256, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 256, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 128, 64, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 128, 16, 16, ck::Sequence<1, 0>, ck::Sequence<16>, ck::Sequence<16>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 256, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 128, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 64, 4, 8, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 256, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 128, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 256, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 64, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 128, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 64, 16, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 32, 8, 4, ck::Sequence<1, 0>, ck::Sequence<4>, ck::Sequence<4>>, +#endif + + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 64, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 128, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 256, 32, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 64, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 32, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 16, 128, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 128, 128, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 32, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 16, 64, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 64, 64, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 32, 16, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>>, + DeviceElementwiseImpl, ck::Tuple, ElementwiseOp, NDims, 32, 16, 32, 4, 4, ck::Sequence<1, 0>, ck::Sequence<1>, ck::Sequence<1>> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt index 86652c0bf6..fc0da56a96 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/permute_scale/CMakeLists.txt @@ -1,7 +1,13 @@ add_instance_library(device_permute_scale_instance - device_permute_scale_1d_instances.cpp - device_permute_scale_2d_instances.cpp - device_permute_scale_3d_instances.cpp - device_permute_scale_4d_instances.cpp - device_permute_scale_5d_instances.cpp - device_permute_scale_6d_instances.cpp) + device_permute_scale_1d_fp16_instances.cpp + device_permute_scale_2d_fp16_instances.cpp + device_permute_scale_3d_fp16_instances.cpp + device_permute_scale_4d_fp16_instances.cpp + device_permute_scale_5d_fp16_instances.cpp + device_permute_scale_6d_fp16_instances.cpp + device_permute_scale_1d_fp32_instances.cpp + device_permute_scale_2d_fp32_instances.cpp + device_permute_scale_3d_fp32_instances.cpp + device_permute_scale_4d_fp32_instances.cpp + device_permute_scale_5d_fp32_instances.cpp + device_permute_scale_6d_fp32_instances.cpp) diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp index 77d3baf4d3..4ee9c1b1c1 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_1d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<1>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_1d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 1>>>& instances) +void add_device_permute_scale_1d_f16_instances( + std::vector, ck::Tuple, Scale, 1>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<1>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<1, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp new file mode 100644 index 0000000000..672acda071 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_1d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_1d_f32_instances( + std::vector, ck::Tuple, Scale, 1>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<1, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp index 399b6b0490..b4a5b107f6 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_2d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<2>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_2d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 2>>>& instances) +void add_device_permute_scale_2d_f16_instances( + std::vector, ck::Tuple, Scale, 2>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<2>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<2, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp new file mode 100644 index 0000000000..5b7b353fc3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_2d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_2d_f32_instances( + std::vector, ck::Tuple, Scale, 2>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<2, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp index 29f2f9fd5c..63876cbc44 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_3d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<3>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_3d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 3>>>& instances) +void add_device_permute_scale_3d_f16_instances( + std::vector, ck::Tuple, Scale, 3>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<3>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<3, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp new file mode 100644 index 0000000000..f8772967dd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_3d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_3d_f32_instances( + std::vector, ck::Tuple, Scale, 3>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<3, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp index 3ad1d59e66..553772e1db 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_4d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<4>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_4d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 4>>>& instances) +void add_device_permute_scale_4d_f16_instances( + std::vector, ck::Tuple, Scale, 4>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<4>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<4, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp new file mode 100644 index 0000000000..f1ecc0ccf0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_4d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_4d_f32_instances( + std::vector, ck::Tuple, Scale, 4>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<4, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp index 6a4383bc95..adb391888a 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_5d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<5>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_5d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 5>>>& instances) +void add_device_permute_scale_5d_f16_instances( + std::vector, ck::Tuple, Scale, 5>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<5>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<5, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp new file mode 100644 index 0000000000..ed53e09b7a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_5d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_5d_f32_instances( + std::vector, ck::Tuple, Scale, 5>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<5, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp similarity index 56% rename from library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp rename to library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp index 71e5867e9a..abf630e433 100644 --- a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_instances.cpp +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp16_instances.cpp @@ -9,18 +9,13 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_permute_scale_6d_f16_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) -{ - add_device_operation_instances(instances, device_permute_scale_f16_instances<6>{}); -} +using Scale = element_wise::Scale; -void add_device_permute_scale_6d_f32_instances( - std::vector, ck::Tuple, Pass, UnaryOp, Scale, 6>>>& instances) +void add_device_permute_scale_6d_f16_instances( + std::vector, ck::Tuple, Scale, 6>>>& + instances) { - add_device_operation_instances(instances, device_permute_scale_f32_instances<6>{}); + add_device_operation_instances(instances, device_permute_scale_f16_instances<6, Scale>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp new file mode 100644 index 0000000000..fbdace20a9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/permute_scale/device_permute_scale_6d_fp32_instances.cpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/permute_scale/device_permute_scale_instances.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Scale = element_wise::Scale; + +void add_device_permute_scale_6d_f32_instances( + std::vector, ck::Tuple, Scale, 6>>>& + instances) +{ + add_device_operation_instances(instances, device_permute_scale_f32_instances<6, Scale>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_permute_scale_impl.hpp b/profiler/include/profiler/profile_permute_scale_impl.hpp index 5bc7c029f4..c69e36142d 100644 --- a/profiler/include/profiler/profile_permute_scale_impl.hpp +++ b/profiler/include/profiler/profile_permute_scale_impl.hpp @@ -8,9 +8,9 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp" #include "ck/library/tensor_operation_instance/gpu/permute_scale.hpp" @@ -21,23 +21,12 @@ #include "ck/library/utility/literals.hpp" namespace ck { -template +template void reference_permute_scale(HostTensorB& b_tensor, const HostTensorA& a_tensor, - AElementOp a_tensor_op, - BElementOp b_tensor_op, - ScaleElementOp scale_op) + ElementOp tensor_op) { - b_tensor.ForEach([&](auto& self, auto idx) { - auto tmp_val = a_tensor(idx); - b_tensor_op(tmp_val, tmp_val); - scale_op(tmp_val, tmp_val); - a_tensor_op(self(idx), tmp_val); - }); + b_tensor.ForEach([&](auto& self, auto idx) { tensor_op(self(idx), a_tensor(idx)); }); } namespace profiler { @@ -54,9 +43,7 @@ bool profile_permute_scale_impl(int do_verification, bool pass = true; bool instance_found = false; - using ElementOp = ck::tensor_operation::element_wise::PassThrough; - using UnaryOp = ck::tensor_operation::element_wise::UnarySquare; - using Scale = ck::tensor_operation::element_wise::Scale; + using ElementOp = ck::tensor_operation::element_wise::Scale; float scale = 2.f; Tensor a(lengths_vector, input_strides_vector); @@ -80,12 +67,8 @@ bool profile_permute_scale_impl(int do_verification, std::array input = {a_device_buf.GetDeviceBuffer()}; std::array output = {b_device_buf.GetDeviceBuffer()}; - using DeviceOp = ck::tensor_operation::device::DeviceElementwise, - ck::Tuple, - ElementOp, - UnaryOp, - Scale, - NumDim>; + using DeviceOp = ck::tensor_operation::device:: + DeviceElementwise, ck::Tuple, ElementOp, NumDim>; // get device op instances const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -100,7 +83,7 @@ bool profile_permute_scale_impl(int do_verification, if(do_verification) { - reference_permute_scale(host_b, a, ElementOp{}, UnaryOp{}, Scale{scale}); + reference_permute_scale(host_b, a, ElementOp{scale}); } auto copy = [](const auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); }; @@ -113,14 +96,8 @@ bool profile_permute_scale_impl(int do_verification, for(auto& op_ptr : op_ptrs) { - auto argument_ptr = op_ptr->MakeArgumentPointer(lengths, - {input_strides}, - {output_strides}, - input, - output, - ElementOp{}, - UnaryOp{}, - Scale{scale}); + auto argument_ptr = op_ptr->MakeArgumentPointer( + lengths, {input_strides}, {output_strides}, input, output, ElementOp{scale}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); @@ -141,6 +118,7 @@ bool profile_permute_scale_impl(int do_verification, if(do_log) { LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "host_b: ", host_b.mData, ",") << std::endl; LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; } } diff --git a/profiler/src/profile_permute_scale.cpp b/profiler/src/profile_permute_scale.cpp index 921b9b9a69..8ebb2289ed 100644 --- a/profiler/src/profile_permute_scale.cpp +++ b/profiler/src/profile_permute_scale.cpp @@ -37,6 +37,20 @@ static void print_helper_msg() // clang-format on } +void init_strides(const std::vector& lengths, + const std::vector& dims_order, + std::vector& strides) +{ + + ck::index_t stride = 1; + for(ck::index_t d = lengths.size() - 1; d >= 0; d--) + { + ck::index_t dim = dims_order[d]; + strides[dim] = stride; + stride *= lengths[dim]; + } +} + } // namespace int profile_permute_scale(int argc, char* argv[]) @@ -58,16 +72,21 @@ int profile_permute_scale(int argc, char* argv[]) const int num_dims = dims_argc / 3; std::vector lengths(num_dims); - std::vector input_strides(num_dims); - std::vector output_strides(num_dims); + std::vector input_dims_order(num_dims); + std::vector output_dims_order(num_dims); for(int i = 0; i < num_dims; i++) { - lengths[i] = std::stoi(argv[control_argc + i]); - input_strides[i] = std::stoi(argv[control_argc + num_dims + i]); - output_strides[i] = std::stoi(argv[control_argc + 2 * num_dims + i]); + lengths[i] = std::stoi(argv[control_argc + i]); + input_dims_order[i] = std::stoi(argv[control_argc + num_dims + i]); + output_dims_order[i] = std::stoi(argv[control_argc + 2 * num_dims + i]); } + std::vector input_strides(num_dims); + std::vector output_strides(num_dims); + init_strides(lengths, input_dims_order, input_strides); + init_strides(lengths, output_dims_order, output_strides); + using F32 = float; using F16 = ck::half_t; diff --git a/script/profile_permute_scale.sh b/script/profile_permute_scale.sh new file mode 100755 index 0000000000..945d10f47b --- /dev/null +++ b/script/profile_permute_scale.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +## GPU visibility +export HIP_VISIBLE_DEVICES=0 +DRIVER="../build/bin/ckProfiler" +echo $DRIVER +OP=$1 +DATATYPE=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +TIME=$6 + + +# 1D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 67108864 0 0 + +# # 2D +# ######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 0 1 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8192 8192 1 0 0 1 + +# 3D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 0 1 2 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 1024 8192 2 1 0 0 1 2 + +# 4D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 0 1 2 3 3 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 512 8192 3 2 1 0 0 1 2 3 + +# 5D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 256 8192 0 1 2 3 4 4 3 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 256 8192 4 3 2 1 0 0 1 2 3 4 + + # 6D +######## op datatype verify init log time dims in_strides_order out_strides_order + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 0 1 2 3 4 5 5 4 3 2 1 0 + $DRIVER $OP $DATATYPE $VERIFY $INIT $LOG $TIME 8 2 2 2 128 8192 5 4 3 2 1 0 0 1 2 3 4 5 + diff --git a/test/permute_scale/test_permute_scale.cpp b/test/permute_scale/test_permute_scale.cpp index 780f6d6edb..e40d4861cf 100644 --- a/test/permute_scale/test_permute_scale.cpp +++ b/test/permute_scale/test_permute_scale.cpp @@ -52,40 +52,40 @@ TYPED_TEST_SUITE(TestPermute, KernelTypes); TYPED_TEST(TestPermute, Test1D) { constexpr ck::index_t NumDims = 1; - this->template Run({8}, {1}, {2}); - this->template Run({8}, {2}, {1}); + this->template Run({16}, {1}, {1}); + this->template Run({16}, {1}, {2}); this->template Run({1}, {1}, {1}); } TYPED_TEST(TestPermute, Test2D) { constexpr ck::index_t NumDims = 2; - this->template Run({8, 4}, {4, 1}, {1, 8}); - this->template Run({8, 4}, {1, 8}, {4, 1}); + this->template Run({8, 16}, {16, 1}, {1, 8}); + this->template Run({8, 16}, {1, 8}, {16, 1}); this->template Run({1, 1}, {1, 1}, {1, 1}); } TYPED_TEST(TestPermute, Test3D) { constexpr ck::index_t NumDims = 3; - this->template Run({2, 4, 4}, {16, 4, 1}, {1, 2, 8}); - this->template Run({2, 4, 4}, {1, 2, 8}, {16, 4, 1}); + this->template Run({8, 2, 8}, {16, 8, 1}, {1, 8, 16}); + this->template Run({8, 2, 8}, {1, 8, 16}, {16, 8, 1}); this->template Run({1, 1, 1}, {1, 1, 1}, {1, 1, 1}); } TYPED_TEST(TestPermute, Test4D) { constexpr ck::index_t NumDims = 4; - this->template Run({2, 4, 4, 4}, {64, 16, 4, 1}, {1, 2, 8, 32}); - this->template Run({2, 4, 4, 4}, {1, 2, 8, 32}, {64, 16, 4, 1}); + this->template Run({8, 2, 3, 8}, {48, 24, 8, 1}, {1, 8, 16, 48}); + this->template Run({8, 2, 3, 8}, {1, 8, 16, 48}, {48, 24, 8, 1}); this->template Run({1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}); } TYPED_TEST(TestPermute, Test5D) { constexpr ck::index_t NumDims = 5; - this->template Run({2, 4, 4, 4, 4}, {256, 64, 16, 4, 1}, {1, 2, 8, 32, 128}); - this->template Run({2, 4, 4, 4, 4}, {1, 2, 8, 32, 128}, {256, 64, 16, 4, 1}); + this->template Run({8, 2, 3, 4, 8}, {192, 96, 32, 8, 1}, {1, 8, 16, 48, 192}); + this->template Run({8, 2, 3, 4, 8}, {1, 8, 16, 48, 192}, {192, 96, 32, 8, 1}); this->template Run({1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {1, 1, 1, 1, 1}); } @@ -93,8 +93,8 @@ TYPED_TEST(TestPermute, Test6D) { constexpr ck::index_t NumDims = 6; this->template Run( - {2, 4, 4, 4, 4, 4}, {1024, 256, 64, 16, 4, 1}, {1, 2, 8, 32, 128, 512}); + {8, 2, 3, 4, 5, 8}, {960, 480, 160, 40, 8, 1}, {1, 8, 16, 48, 192, 960}); this->template Run( - {2, 4, 4, 4, 4, 4}, {1, 2, 8, 32, 128, 512}, {1024, 256, 64, 16, 4, 1}); + {8, 2, 3, 4, 5, 8}, {1, 8, 16, 48, 192, 960}, {960, 480, 160, 40, 8, 1}); this->template Run({1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 1}); } From 2ae16e901f75594022848a05ba9c1b6d0e3e4d6d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 22 Mar 2024 07:58:36 -0700 Subject: [PATCH 35/36] Bump rocm-docs-core from 0.37.0 to 0.37.1 in /docs/sphinx (#1211) Bumps [rocm-docs-core](https://github.com/RadeonOpenCompute/rocm-docs-core) from 0.37.0 to 0.37.1. - [Release notes](https://github.com/RadeonOpenCompute/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/develop/CHANGELOG.md) - [Commits](https://github.com/RadeonOpenCompute/rocm-docs-core/compare/v0.37.0...v0.37.1) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index ae92cc6c10..76ec2700ca 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==0.37.0 +rocm-docs-core==0.37.1 sphinxcontrib-bibtex==2.6.2 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 43853dd3fa..ab2415f0c9 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -111,7 +111,7 @@ requests==2.31.0 # via # pygithub # sphinx -rocm-docs-core==0.37.0 +rocm-docs-core==0.37.1 # via -r requirements.in six==1.16.0 # via From cc1f733d0eaab81c2185888668479fb30b200bdb Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Fri, 22 Mar 2024 15:39:11 -0700 Subject: [PATCH 36/36] allow the CI to pass even if can't connect to db (#1214) --- Jenkinsfile | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index ec3cbd0e27..654c7274f4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -619,6 +619,8 @@ def process_results(Map conf=[:]){ dir("script"){ if (params.RUN_FULL_QA){ // unstash perf files to master + unstash "ckprofiler_0.2.0_amd64.deb" + sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" unstash "perf_gemm.log" unstash "perf_resnet50_N256.log" unstash "perf_resnet50_N4.log" @@ -632,8 +634,6 @@ def process_results(Map conf=[:]){ unstash "perf_onnx_gemm.log" unstash "perf_mixed_gemm.log" sh "./process_qa_data.sh" - unstash "ckprofiler_0.2.0_amd64.deb" - sh "sshpass -p ${env.ck_deb_pw} scp -o StrictHostKeyChecking=no ckprofiler_0.2.0_amd64.deb ${env.ck_deb_user}@${env.ck_deb_ip}:/var/www/html/composable_kernel/" } else{ // unstash perf files to master @@ -645,10 +645,13 @@ def process_results(Map conf=[:]){ } } catch(e){ - echo "throwing error exception while processing performance test results" + echo "Throwing error exception while processing performance test results" echo 'Exception occurred: ' + e.toString() throw e } + finally{ + echo "Finished processing performance test results" + } } } }