From 7ac379428408337a231a86f8a8b7353b5b45aa2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <188998872+vpietila-amd@users.noreply.github.com> Date: Sun, 25 Jan 2026 14:42:23 +0200 Subject: [PATCH 01/28] Add new instances for merging multiple fwd conv groups into a single GEMM batch. Allow group merging for C > 1 when vector load/store size is 1 for the output tensor. (#3639) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Ville Pietilä <> --- ...vice_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp | 2 +- ...ice_grouped_conv_fwd_xdl_merged_groups_instance.hpp | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cc343f6f69..d3e0d6057d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1513,7 +1513,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle if constexpr(NumGroupsToMerge > 1) { - if(!(C == 1)) + if(!(C == 1) && CDEBlockTransferScalarPerVector_NPerBlock > 1) { return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 944e68f192..18abcb1613 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -116,9 +116,13 @@ using device_grouped_conv_fwd_xdl_merged_groups_f16_instances_2x = std::tuple< //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // Instances with NumGroupsPerBatch > 1 - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S< 4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8> // clang-format on >; From 054c437dec3bc0d0059f045dc768b950db315846 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 26 Jan 2026 09:23:19 -0800 Subject: [PATCH 02/28] add dockerfile for manylinux (#3651) --- Dockerfile.manylinux | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 Dockerfile.manylinux diff --git a/Dockerfile.manylinux b/Dockerfile.manylinux new file mode 100644 index 0000000000..0683bcd4a6 --- /dev/null +++ b/Dockerfile.manylinux @@ -0,0 +1,101 @@ +FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest +ARG DEBIAN_FRONTEND=noninteractive +ARG ROCMVERSION=7.2 +ARG compiler_version="" +ARG compiler_commit="" +ARG CK_SCCACHE="" +ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/ +ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn +ENV DEBIAN_FRONTEND=noninteractive + +USER root + +# Add rocm repository +RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y + +RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \ + dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \ + dnf update -y && \ + dnf install python3-setuptools python3-wheel -y && \ + dnf install rocm-dev -y + +## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined +ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache +ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin +ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION} +ENV CK_SCCACHE=$CK_SCCACHE +RUN if [ "$CK_SCCACHE" != "" ]; then \ + mkdir -p ${SCCACHE_INSTALL_LOCATION} && \ + curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \ + chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \ + fi + +# Install dependencies +RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \ + cmake \ + clang-tools-extra \ + gcc-c++ \ + libstdc++ \ + libstdc++-devel \ + libstdc++-static \ + git \ + hip-rocclr \ + jq \ + mpich \ + net-tools \ + pkg-config \ + redis \ + sshpass \ + stunnel \ + vim \ + nano \ + zip \ + openssh-server \ + kmod && \ + dnf clean all && \ + rm -rf /var/lib/apt/lists/* && \ + rm -rf amdgpu-install* && \ +#Install latest ccache + git clone https://github.com/ccache/ccache.git && \ + cd ccache && mkdir build && cd build && cmake .. && make install && \ +#Install ClangBuildAnalyzer + git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \ + cd ClangBuildAnalyzer/ && \ + make -f projects/make/Makefile && \ + cd / && \ +#Install latest cppcheck + git clone https://github.com/danmar/cppcheck.git && \ + cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \ + cd / && \ +# Install packages for processing the performance results + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ +# Add render group + groupadd -f render && \ +# Install the new rocm-cmake version + git clone -b master https://github.com/ROCm/rocm-cmake.git && \ + cd rocm-cmake && mkdir build && cd build && \ + cmake .. && cmake --build . && cmake --build . --target install + +WORKDIR / +# Add alternative compilers, if necessary +ENV compiler_version=$compiler_version +ENV compiler_commit=$compiler_commit +RUN sh -c "echo compiler version = '$compiler_version'" && \ + sh -c "echo compiler commit = '$compiler_commit'" + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + +RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \ + git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \ + cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \ + cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \ + make -j 8 ; \ + else echo "using the release compiler"; \ + fi + From de59c0716c631edfa4742e4309ee11d4379ef6e8 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:08:55 -0800 Subject: [PATCH 03/28] Optimize sequence metaprogramming utilities to reduce template instantiation depth (#3585) This change significantly improves compile-time performance by reducing template instantiation depth for sequence generation and merging operations: Optimizations: - sequence_gen: Reduce instantiation depth from O(log N) to O(1) by using __make_integer_seq to generate indices in a single step, then applying the functor via pack expansion - uniform_sequence_gen: Similarly optimized to O(1) depth using __make_integer_seq with a helper that applies a constant value via pack expansion - sequence_merge: Reduce depth from O(N) to O(log N) using binary tree reduction strategy. Added direct concatenation specializations for 1-4 sequences to avoid recursion in common cases, falling back to binary tree merging for 5+ sequences Documentation: - Added extensive inline comments explaining why sequence_merge cannot achieve O(1) depth like sequence_gen (requires computing cumulative sequence lengths from heterogeneous inputs, inherently requiring recursion) - Documented the binary tree reduction approach and why it's superior to fold expressions for this use case Testing: - Added comprehensive unit tests for uniform_sequence_gen with different values, sizes, and edge cases - Added tests for sequence_gen with custom functors (double, square, identity, constant) to verify the new implementation works with arbitrary functors - Added tests for sequence_merge with 4, 5, and many sequences to verify both the direct concatenation path and binary tree reduction path - Added tests for empty sequence edge cases --- include/ck/utility/sequence.hpp | 152 +++++++++++++----- .../ck/utility/statically_indexed_array.hpp | 1 + test/util/unit_sequence.cpp | 134 +++++++++++++++ 3 files changed, 247 insertions(+), 40 deletions(-) diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 6e68690048..3a45d52bd3 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -199,55 +199,113 @@ template using make_index_sequence = typename __make_integer_seq::seq_type; -// merge sequence -template -struct sequence_merge +// merge sequence - optimized to avoid recursive instantiation +// +// Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1) +// instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why: +// +// - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each +// element can be computed independently: output[i] = f(i) +// +// - sequence_merge takes MULTIPLE input sequences with different, unknown lengths. +// To compute output[i], we need to know: +// 1. Which input sequence contains this index +// 2. The offset within that sequence +// This requires computing cumulative sequence lengths, which requires recursion/iteration. +// +// Instead, we use a binary tree reduction approach that achieves O(log N) instantiation depth: +// - Base cases handle 1-4 sequences directly (O(1) for common cases) +// - Recursive case merges pairs then combines: merge(s1,s2) + merge(s3,s4,...) +// - This gives O(log N) depth, which is optimal for merging heterogeneous sequences +// +// Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to +// linear dependency chain, so binary tree is superior. +// +namespace detail { + +// Helper to concatenate multiple sequences in one step using fold expression +template +struct sequence_merge_impl; + +// Base case: single sequence +template +struct sequence_merge_impl> { - using type = typename sequence_merge::type>::type; + using type = Sequence; }; +// Two sequences: direct concatenation template -struct sequence_merge, Sequence> +struct sequence_merge_impl, Sequence> { using type = Sequence; }; -template -struct sequence_merge +// Three sequences: direct concatenation (avoids one level of recursion) +template +struct sequence_merge_impl, Sequence, Sequence> { - using type = Seq; + using type = Sequence; }; -// generate sequence +// Four sequences: direct concatenation +template +struct sequence_merge_impl, Sequence, Sequence, Sequence> +{ + using type = Sequence; +}; + +// General case: binary tree reduction (O(log N) depth instead of O(N)) +template +struct sequence_merge_impl +{ + // Merge pairs first, then recurse + using left = typename sequence_merge_impl::type; + using right = typename sequence_merge_impl::type; + using type = typename sequence_merge_impl::type; +}; + +} // namespace detail + +template +struct sequence_merge +{ + using type = typename detail::sequence_merge_impl::type; +}; + +template <> +struct sequence_merge<> +{ + using type = Sequence<>; +}; + +// generate sequence - optimized using __make_integer_seq to avoid recursive instantiation +namespace detail { + +// Helper that applies functor F to indices and produces a Sequence +// __make_integer_seq produces sequence_gen_helper +template +struct sequence_gen_helper +{ + // Apply a functor F to all indices at once via pack expansion (O(1) depth) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + template struct sequence_gen { - template - struct sequence_gen_impl - { - static constexpr index_t NRemainLeft = NRemain / 2; - static constexpr index_t NRemainRight = NRemain - NRemainLeft; - static constexpr index_t IMiddle = IBegin + NRemainLeft; + using type = + typename __make_integer_seq::template apply; +}; - using type = typename sequence_merge< - typename sequence_gen_impl::type, - typename sequence_gen_impl::type>::type; - }; - - template - struct sequence_gen_impl - { - static constexpr index_t Is = G{}(Number{}); - using type = Sequence; - }; - - template - struct sequence_gen_impl - { - using type = Sequence<>; - }; - - using type = typename sequence_gen_impl<0, NSize, F>::type; +template +struct sequence_gen<0, F> +{ + using type = Sequence<>; }; // arithmetic sequence @@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1> using type = typename __make_integer_seq::type; }; -// uniform sequence +// uniform sequence - optimized using __make_integer_seq +namespace detail { + +template +struct uniform_sequence_helper +{ + // Apply a constant value to all indices via pack expansion + template + using apply = Sequence<((void)Is, Value)...>; +}; + +} // namespace detail + template struct uniform_sequence_gen { - struct F - { - __host__ __device__ constexpr index_t operator()(index_t) const { return I; } - }; + using type = typename __make_integer_seq:: + template apply; +}; - using type = typename sequence_gen::type; +template +struct uniform_sequence_gen<0, I> +{ + using type = Sequence<>; }; // reverse inclusive scan (with init) sequence diff --git a/include/ck/utility/statically_indexed_array.hpp b/include/ck/utility/statically_indexed_array.hpp index d0735a32f6..f3d73e84a7 100644 --- a/include/ck/utility/statically_indexed_array.hpp +++ b/include/ck/utility/statically_indexed_array.hpp @@ -20,6 +20,7 @@ struct tuple_concat, Tuple> using type = Tuple; }; +// StaticallyIndexedArrayImpl uses binary split for O(log N) depth template struct StaticallyIndexedArrayImpl { diff --git a/test/util/unit_sequence.cpp b/test/util/unit_sequence.cpp index f09fd86e06..9e62b9a6c0 100644 --- a/test/util/unit_sequence.cpp +++ b/test/util/unit_sequence.cpp @@ -229,6 +229,32 @@ TEST(SequenceGen, UniformSequenceZeroSize) EXPECT_TRUE((is_same::value)); } +TEST(SequenceGen, UniformSequenceSingleElement) +{ + using Result = typename uniform_sequence_gen<1, 99>::type; + using Expected = Sequence<99>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceDifferentValues) +{ + using Result1 = typename uniform_sequence_gen<3, 0>::type; + using Expected1 = Sequence<0, 0, 0>; + EXPECT_TRUE((is_same::value)); + + using Result2 = typename uniform_sequence_gen<4, -5>::type; + using Expected2 = Sequence<-5, -5, -5, -5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, UniformSequenceLargeSize) +{ + // Test with larger size to verify __make_integer_seq implementation + using Result = typename uniform_sequence_gen<16, 7>::type; + using Expected = Sequence<7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7>; + EXPECT_TRUE((is_same::value)); +} + // Test make_index_sequence TEST(SequenceGen, MakeIndexSequence) { @@ -244,6 +270,54 @@ TEST(SequenceGen, MakeIndexSequenceZero) EXPECT_TRUE((is_same::value)); } +// Test sequence_gen with custom functors +TEST(SequenceGen, SequenceGenWithDoubleFunctor) +{ + struct DoubleFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * 2; } + }; + using Result = typename sequence_gen<5, DoubleFunctor>::type; + using Expected = Sequence<0, 2, 4, 6, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenWithSquareFunctor) +{ + struct SquareFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i * i; } + }; + using Result = typename sequence_gen<5, SquareFunctor>::type; + using Expected = Sequence<0, 1, 4, 9, 16>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceGen, SequenceGenZeroSize) +{ + struct IdentityFunctor + { + __host__ __device__ constexpr index_t operator()(index_t i) const { return i; } + }; + using Result = typename sequence_gen<0, IdentityFunctor>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); + // Also verify non-zero size works with identity + using Result5 = typename sequence_gen<5, IdentityFunctor>::type; + EXPECT_TRUE((is_same>::value)); +} + +TEST(SequenceGen, SequenceGenSingleElement) +{ + struct ConstantFunctor + { + __host__ __device__ constexpr index_t operator()(index_t) const { return 42; } + }; + using Result = typename sequence_gen<1, ConstantFunctor>::type; + using Expected = Sequence<42>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_merge TEST(SequenceMerge, MergeTwoSequences) { @@ -272,6 +346,66 @@ TEST(SequenceMerge, MergeSingleSequence) EXPECT_TRUE((is_same::value)); } +TEST(SequenceMerge, MergeFourSequences) +{ + // Test the 4-sequence specialization + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2, 3>; + using Seq3 = Sequence<4, 5, 6>; + using Seq4 = Sequence<7, 8>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeFiveSequences) +{ + // Test the binary tree reduction path (5+ sequences) + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3>; + using Seq4 = Sequence<4>; + using Seq5 = Sequence<5>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeManySequences) +{ + // Test with many sequences to stress the binary tree reduction + using Seq1 = Sequence<1>; + using Seq2 = Sequence<2>; + using Seq3 = Sequence<3, 4>; + using Seq4 = Sequence<5>; + using Seq5 = Sequence<6, 7>; + using Seq6 = Sequence<8>; + using Seq7 = Sequence<9, 10>; + using Seq8 = Sequence<11, 12>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeEmptySequences) +{ + // Test merging empty sequences + using Seq1 = Sequence<>; + using Seq2 = Sequence<1, 2>; + using Seq3 = Sequence<>; + using Result = typename sequence_merge::type; + using Expected = Sequence<1, 2>; + EXPECT_TRUE((is_same::value)); +} + +TEST(SequenceMerge, MergeZeroSequences) +{ + // Test the empty specialization + using Result = typename sequence_merge<>::type; + using Expected = Sequence<>; + EXPECT_TRUE((is_same::value)); +} + // Test sequence_split TEST(SequenceSplit, SplitInMiddle) { From 917f35553a46286eb3364abec4de5267d2aa92b0 Mon Sep 17 00:00:00 2001 From: chris-tsiaousis-hpc Date: Mon, 26 Jan 2026 19:20:30 +0100 Subject: [PATCH 04/28] Remove code duplications in batched gemm (multi D) gemm (multi D) wmma (#3617) * Added common struct to enable code reduction in gemm gemm and gemm multi_d gemm multi_d wmma implementation This file includes all shared components. The (shared between the two implementations) kernel, the pointer offset computation struct, the grid descriptor creator and definitions, the invoker struct and the argument struct. Signed-off-by: Chris Tsiaousis * Used the common struct in the batched gemm gemm wmma cshuffle v3 implementation Signed-off-by: Chris Tsiaousis * Used the shared structs in the gemm multiple D gemm multiple D wmma cshuffle v3 implementation Signed-off-by: Chris Tsiaousis * Boy-scout: IWYU paradigm in the gemm gemm and gemm multiple D gemm multiple D wmma cshuffle v3 implementations Signed-off-by: Chris Tsiaousis --------- Signed-off-by: Chris Tsiaousis --- ...ice_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 618 +++--------- ...ched_gemm_gemm_wmma_cshuffle_v3_common.hpp | 902 ++++++++++++++++++ ...ple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp | 816 +++------------- 3 files changed, 1173 insertions(+), 1163 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp index 45ec3a2065..6b1144047f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -3,77 +3,21 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" #include "ck/utility/tuple.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t c_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCBasePtr(g_idx))); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - Tuple<>{}, // p_d0s_grid - arg.p_b1_grid + b1_batch_offset, - Tuple<>{}, // p_d1s_grid - arg.p_c_grid + c_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.b1_grid_desc, - Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - arg.c_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.c_element_op, - arg.block_2_ctile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes C = A * B0 * B1 // MN = MK * KL * LN // ^^^^^^ (Acc0) @@ -157,88 +101,47 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC) - : BatchStrideA_(BatchStrideA), - BatchStrideB0_(BatchStrideB0), - BatchStrideB1_(BatchStrideB1), - BatchStrideC_(BatchStrideC) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideC_); - } - - private: - index_t BatchStrideA_; - index_t BatchStrideB0_; - index_t BatchStrideB1_; - index_t BatchStrideC_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType + Tuple<>, // D1sDataType + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -260,12 +163,12 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm, // Ds0GridDesc - B1GridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, Tuple<>, // Ds1GridDesc - CGridDesc_M_N, + typename DeviceGemmGemmCommonBase::CGridDesc_M_N, // Tiling Family MPerBlock, LPerBlock, @@ -312,339 +215,67 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + Tuple<>, // D0sLayout + B1Layout, + Tuple<>, // D1sLayout + CLayout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + CDataType, + Tuple<>, // D0sDataType, + Tuple<>, // D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + ck::index_t{}, // CDE0BlockTransferSrcScalarPerVector + CShuffleBlockTransferScalarPerVector_NPerBlock, + false>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - const B1DataType* p_b1_grid_, - CDataType* p_c_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - index_t StrideB1, - index_t StrideC, - index_t BatchStrideA, - index_t BatchStrideB0, - index_t BatchStrideB1, - index_t BatchStrideC, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CElementwiseOperation c_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_b1_grid{p_b1_grid_}, - p_c_grid{p_c_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - c_element_op{c_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, BatchStrideB0, BatchStrideB1, BatchStrideC} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - c_g_m_o_lengths = arr3{batch_count, M, O}; - c_g_m_o_strides = arr3{BatchStrideC, StrideC, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides); - c_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); - block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1); - } - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - const B1DataType* p_b1_grid; - CDataType* p_c_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - arr3 c_g_m_o_lengths; - arr3 c_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CElementwiseOperation c_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - B1GridDesc b1_grid_desc; - CGridDesc_M_N c_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B layout must be Column\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - Tuple<>{}, - arg.b1_grid_desc, - Tuple<>{}, - arg.c_grid_desc_m_n, - arg.block_2_ctile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto c_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto c_stride_lowest = arg.c_g_m_o_strides[2]; - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - c_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_gemm_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - // polymorphic std::unique_ptr MakeArgumentPointer(const void* p_a, const void* p_b0, @@ -669,28 +300,39 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm(static_cast(p_a), - static_cast(p_b0), - static_cast(p_b1), - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideB1, - StrideC, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - BatchStrideC, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + + std::array p_d0_grid{}; + std::array p_d1_grid{}; + std::array StrideD0s{}, BatchStrideD0s{}; + std::array StrideD1s, BatchStrideD1s{}; + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0_grid, + static_cast(p_b1), + p_d1_grid, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideC, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..a739af898f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,902 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/integral_constant.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_wmma_cshuffle_v3(typename DeviceOp::Argument arg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_e1_batch_offset = + __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetCE1BasePtr(g_idx))); + + auto [p_d0s_grid, p_d1s_grid] = [&]() { + if constexpr(IsMultiD) + { + auto create_grid = [](auto NumTensor, auto func, auto& arg_grid, auto&& grid_pointer) { + static_for<0, decltype(NumTensor)::value, 1>{}([&](auto In) { + const long_index_t batch_offset = __builtin_amdgcn_readfirstlane(func(In)); + grid_pointer(In) = arg_grid(In) + batch_offset; + }); + return std::move(grid_pointer); + }; + auto get_d0_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, d_idx); + }; + auto get_d1_base_ptr = [&arg, &g_idx](auto d_idx) { + return arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, d_idx); + }; + auto d0s_grid = create_grid(ck::integral_constant{}, + get_d0_base_ptr, + arg.p_d0s_grid, + GridwiseOp::MakeD0sGridPointer()); + auto d1s_grid = create_grid(ck::integral_constant{}, + get_d1_base_ptr, + arg.p_d1s_grid, + GridwiseOp::MakeD1sGridPointer()); + return std::make_pair(d0s_grid, d1s_grid); + } + else + { + return std::make_pair(Tuple<>{}, Tuple<>{}); + } + }(); + + GridwiseOp::template Run( + arg.p_a_grid + a_batch_offset, + arg.p_b0_grid + b0_batch_offset, + p_d0s_grid, + arg.p_b1_grid + b1_batch_offset, + p_d1s_grid, + arg.p_c_e1_grid + c_e1_batch_offset, + p_shared, + arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, + arg.c_e1_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op, + arg.b0_element_op, + arg.acc_element_op, + arg.b1_element_op, + arg.cde1_element_op, + arg.block_2_etile_map); +#else + ignore = arg; +#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) +} + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common +{ + static constexpr ck::index_t NumD0Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD0Tensor; + } + return 0; + }(); + static constexpr ck::index_t NumD1Tensor = []() { + if constexpr(IsMultiD) + { + return DeviceOp::NumD1Tensor; + } + return 0; + }(); + + struct GridDescriptorCreator + { + // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler + // Transform operator or just not use one at all. + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence<1, 1, 1, 1, 1>, + Sequence, + GemmSpec, + TensorSpecialization::Default, // ASpec + TensorSpecialization::Default, // B0Spec + TensorSpecialization::Default, // B1Spec + TensorSpecialization::Default>; // CSpec + + __host__ __device__ static auto + MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, + const std::array& a_g_m_k_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, + const std::array& b0_g_l_k_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, + const std::array& b1_g_n_l_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), + Number{}); + } + + __host__ __device__ static auto + MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, + const std::array& d0_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); + } + + __host__ __device__ static auto MakeD0sGridDescriptor( + const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, + const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto MakeD1sGridDescriptor( + const std::array, NumD1Tensor>& d1_g_m_o_lengths_vec, + const std::array, NumD1Tensor>& d1_g_m_o_strides_vec) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); + }, + Number{}); + } + + __host__ __device__ static auto + MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, + const std::array& e1_g_m_n_strides_vec) + { + return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); + } + }; + + using AGridDesc = decltype(GridDescriptorCreator::MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(GridDescriptorCreator::MakeB0GridDescriptor({}, {})); + using D0sGridDesc = + remove_cvref_t; + using B1GridDesc = decltype(GridDescriptorCreator::MakeB1GridDescriptor({}, {})); + using D1sGridDesc = + remove_cvref_t; + using E1GridDesc = decltype(GridDescriptorCreator::MakeE1GridDescriptor({}, {})); + using CGridDesc_M_N = + decltype(GridDescriptorCreator::Transform::MakeCGridDescriptor_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB0, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB0_(BatchStrideB0), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_E1_(BatchStrideC) + { + } + + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideC_E1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCE1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_E1_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d0_idx]); + } + + template + __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideC_E1_; + }; +}; + +template +struct DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg +{ + using GridwiseGemm = typename DeviceOp::GridwiseOp; + using Common = + DeviceGemmGemm_Wmma_CShuffleV3_Common; + + static constexpr auto NumD0Tensor = Common::NumD0Tensor; + static constexpr auto NumD1Tensor = Common::NumD1Tensor; + + struct Argument : public BaseArgument + { + using arr3 = std::array; + + Argument(const ADataType* p_a_grid_, + const B0DataType* p_b0_grid_, + std::array p_d0s_grid_, + const B1DataType* p_b1_grid_, + std::array p_d1s_grid_, + CE1DataType* p_e1_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t O_, + index_t Batch, + index_t StrideA, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + AElementwiseOperation a_element_op_, + B0ElementwiseOperation b0_element_op_, + AccElementwiseOperation acc_element_op_, + B1ElementwiseOperation b1_element_op_, + CDE1ElementwiseOperation cde1_element_op_) + : p_a_grid{p_a_grid_}, + p_b0_grid{p_b0_grid_}, + p_d0s_grid{}, + p_b1_grid{p_b1_grid_}, + p_d1s_grid{}, + p_c_e1_grid{p_e1_grid_}, + M{M_}, + N{N_}, + K{K_}, + O{O_}, + batch_count{Batch}, + a_element_op{a_element_op_}, + b0_element_op{b0_element_op_}, + acc_element_op{acc_element_op_}, + b1_element_op{b1_element_op_}, + cde1_element_op{cde1_element_op_}, + compute_base_ptr_of_batch{BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + + a_g_m_k_lengths = arr3{batch_count, M, K}; + a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] + + b0_g_n_k_lengths = arr3{batch_count, N, K}; + b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] + + b1_g_o_n_lengths = arr3{batch_count, O, N}; + b1_g_o_n_strides = + is_same_v + ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] + : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] + + e1_g_m_o_lengths = arr3{batch_count, M, O}; + e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] + + a_grid_desc = Common::GridDescriptorCreator::MakeAGridDescriptor(a_g_m_k_lengths, + a_g_m_k_strides); + b0_grid_desc = Common::GridDescriptorCreator::MakeB0GridDescriptor(b0_g_n_k_lengths, + b0_g_n_k_strides); + b1_grid_desc = Common::GridDescriptorCreator::MakeB1GridDescriptor(b1_g_o_n_lengths, + b1_g_o_n_strides); + c_e1_grid_desc_m_n = Common::GridDescriptorCreator::MakeE1GridDescriptor( + e1_g_m_o_lengths, e1_g_m_o_strides); + c_e1_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_e1_grid_desc_m_n); + block_2_etile_map = GridwiseGemm::MakeDefaultBlock2ETileMap(c_e1_grid_desc_m_n, 1, 1); + + if constexpr(IsMultiD) + { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + + // D0s layout [batch_count, M, N] + d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; + d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; + + // D0 pointer + p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); + }); + // D0 desc + d0s_grid_desc = Common::GridDescriptorCreator::MakeD0sGridDescriptor( + d0s_g_m_n_lengths, d0s_g_m_n_strides); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1DataType = remove_cvref_t>; + + // D1s layout [batch_count, M, O] + d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; + d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; + + // D1 pointer + p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); + }); + // D1 desc + d1s_grid_desc = Common::GridDescriptorCreator::MakeD1sGridDescriptor( + d1s_g_m_o_lengths, d1s_g_m_o_strides); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc); + } + } + + // Pointers + const ADataType* p_a_grid; + const B0DataType* p_b0_grid; + typename GridwiseGemm::D0sGridPointer p_d0s_grid; + const B1DataType* p_b1_grid; + typename GridwiseGemm::D1sGridPointer p_d1s_grid; + CE1DataType* p_c_e1_grid; + + // Raw Problem Size + index_t M; + index_t N; + index_t K; + index_t O; + index_t batch_count; + + arr3 a_g_m_k_lengths; + arr3 a_g_m_k_strides; + arr3 b0_g_n_k_lengths; + arr3 b0_g_n_k_strides; + std::array d0s_g_m_n_lengths; + std::array d0s_g_m_n_strides; + arr3 b1_g_o_n_lengths; + arr3 b1_g_o_n_strides; + std::array d1s_g_m_o_lengths; + std::array d1s_g_m_o_strides; + arr3 e1_g_m_o_lengths; + arr3 e1_g_m_o_strides; + + AElementwiseOperation a_element_op; + B0ElementwiseOperation b0_element_op; + AccElementwiseOperation acc_element_op; + B1ElementwiseOperation b1_element_op; + CDE1ElementwiseOperation cde1_element_op; + + // Grid descriptors and other mem calculators + typename Common::AGridDesc a_grid_desc; + typename Common::B0GridDesc b0_grid_desc; + std::conditional_t> d0s_grid_desc; + typename Common::B1GridDesc b1_grid_desc; + typename Common::D1sGridDesc d1s_grid_desc; + std::conditional_t< + IsMultiD, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + Tuple<>> + d1s_grid_desc_mblock_mperblock_nblock_nperblock; + + std::conditional_t + c_e1_grid_desc_m_n; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_e1_grid_desc_mblock_mperblock_nblock_nperblock; + + typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map; + + typename Common::ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); + + const index_t grid_size = arg.batch_count * M0 * N0; + + auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { + constexpr bool has_loop = decltype(has_main_k_block_loop)::value; + constexpr TailNumber tail_num = decltype(tail_number)::value; + const auto kernel = kernel_batched_gemm_gemm_wmma_cshuffle_v3; + return launch_and_time_kernel( + stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); + }; + + bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(arg.K); + TailNumber TailNum = GridwiseGemm::CalculateKBlockLoopTailNum(arg.K); + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); + return 0.0f; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(HasMainKBlockLoop && TailNum == TailNumber::Full) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) + { + return launch_kernel(std::integral_constant{}, + std::integral_constant{}); + } + else + { + printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); + return 0.0f; + } + } + else + { + printf("Invalid pipeline version!\n"); + return 0.0f; + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + // check if DsLayout is supported + template + static constexpr bool CheckDLayout() + { + bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Print lambda with env check and printf() style formmating. + const char* curFunc = __func__; + auto print = [&curFunc](const char* format, ...) -> void { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { +#if defined(__clang__) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wformat-nonliteral" +#endif + va_list args; + va_start(args, format); + std::vfprintf(stdout, format, args); + va_end(args); +#if defined(__clang__) +#pragma clang diagnostic pop +#endif + std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; + } + }; + + if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + print("DeviceOp: Arch err\n"); + return false; + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + print("DeviceOp: gfx 11 does not support fp8\n"); + return false; + } + } + + if constexpr(!(is_same_v || is_same_v)) + { + print("DeviceOp: Acc0 Type err\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: A layout must be Row\n"); + return false; + } + + if constexpr(!(is_same_v || + is_same_v)) + { + print("DeviceOp: B1 layout must be Column or Row\n"); + return false; + } + + if constexpr(!(is_same_v)) + { + print("DeviceOp: C layout must be Row\n"); + return false; + } + + // Other padding modes have not been tested and do not get checked individually. + if constexpr(GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MNKOPadding) + { + print("Padding mode must be default or MNKO\n"); + return false; + } + + // Per wmma dimensions not equal to 16 are very untested. + if constexpr(MPerWmma != 16 || LPerWmma != 16 || DeviceOp::NPerWmma != 16) + { + print("M, L, N per Wmma must be 16\n"); + return false; + } + + if constexpr(IsMultiD) + { + if constexpr(!(is_same_v)) + { + print("DeviceOp: B0 layout must be Column\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D0s layout must be Row\n"); + return false; + } + + if constexpr(!(CheckDLayout())) + { + print("DeviceOp: All D1s layout must be Row\n"); + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.d0s_grid_desc, + arg.b1_grid_desc, + arg.d1s_grid_desc, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto cde1_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; + + // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major + // and the lowest dimension stride is hardcoded to 1 + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + e1_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + else + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + Tuple<>{}, + arg.b1_grid_desc, + Tuple<>{}, + arg.c_e1_grid_desc_m_n, + arg.block_2_etile_map)) + { + return false; + } + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; + const auto c_extent_lowest = arg.O; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + print("DeviceOp: Data Transfer Vector scalar err\n"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; + const auto b0_stride_lowest = B0BlockTransferSrcVectorDim == 2 + ? arg.b0_g_n_k_strides[2] + : arg.b0_g_n_k_strides[1]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? arg.b1_g_o_n_strides[2] + : arg.b1_g_o_n_strides[1]; + const auto c_stride_lowest = arg.e1_g_m_o_strides[2]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + print("DeviceOp: Data Vectorize transfer err\n"); + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) + { + return false; + } + + return true; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp index 06651c0c0e..83fec9c95f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -3,91 +3,20 @@ #pragma once -#include #include -#include #include #include #include "ck/ck.hpp" -#include "ck/utility/common_header.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_wmma_cshuffle_v3_common.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp" -#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" namespace ck { namespace tensor_operation { namespace device { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3(typename DeviceOp::RawArg arg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) - - __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; - const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / arg.batch_count); - const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - - const long_index_t a_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetABasePtr(g_idx))); - const long_index_t b0_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); - const long_index_t b1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); - const long_index_t e1_batch_offset = - __builtin_amdgcn_readfirstlane((arg.compute_base_ptr_of_batch.GetE1BasePtr(g_idx))); - - auto p_d0s_grid = GridwiseOp::MakeD0sGridPointer(); - auto p_d1s_grid = GridwiseOp::MakeD1sGridPointer(); - - static_for<0, DeviceOp::NumD0Tensor, 1>{}([&](auto In) { - const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); - p_d0s_grid(In) = arg.p_d0s_grid(In) + d0_batch_offset; - }); - - static_for<0, DeviceOp::NumD1Tensor, 1>{}([&](auto In) { - const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(arg.compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); - p_d1s_grid(In) = arg.p_d1s_grid(In) + d1_batch_offset; - }); - - GridwiseOp::template Run( - arg.p_a_grid + a_batch_offset, - arg.p_b0_grid + b0_batch_offset, - p_d0s_grid, - arg.p_b1_grid + b1_batch_offset, - p_d1s_grid, - arg.p_e1_grid + e1_batch_offset, - p_shared, - arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock, - arg.e1_grid_desc_mblock_mperblock_nblock_nperblock, - arg.a_element_op, - arg.b0_element_op, - arg.acc_element_op, - arg.b1_element_op, - arg.cde1_element_op, - arg.block_2_etile_map); -#else - ignore = arg; -#endif // (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__) -} - // Computes: // Acc = Acc_Op(A_Op(A) * B0_Op(B0), D0_0, D0_1, ...) // E = CDE1_Op(Acc_Op(Acc0) * B1_Op(B1), D1_0, D1_1, ...) @@ -184,151 +113,51 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 static constexpr index_t NumD0Tensor = D0sDataType::Size(); static constexpr index_t NumD1Tensor = D1sDataType::Size(); - static constexpr auto I0 = Number<0>{}; - // To match XDL implementation NPerWmma (A.k.a Gemm1 NPerWmma) is set equal // to LPerWmma (A.k.a Gemm0 NPerWmma). static constexpr index_t NPerWmma = LPerWmma; - // TODO: Now that we are no longer using NumDim or TensorSpec, we can probably use a simpler - // Transform operator or just not use one at all. - using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< - Sequence<1, 1, 1, 1, 1>, - Sequence, - GemmSpec, - TensorSpecialization::Default, // ASpec - TensorSpecialization::Default, // B0Spec - TensorSpecialization::Default, // B1Spec - TensorSpecialization::Default>; // CSpec - - __host__ __device__ static auto - MakeAGridDescriptor(const std::array& a_g_m_k_lengths_vec, - const std::array& a_g_m_k_strides_vec) - { - return Transform::MakeAGridDescriptor_AK0_M_AK1( - Transform::MakeAGridDescriptor_M_K(a_g_m_k_lengths_vec, a_g_m_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB0GridDescriptor(const std::array& b0_g_l_k_lengths_vec, - const std::array& b0_g_l_k_strides_vec) - { - return Transform::MakeB0GridDescriptor_BK0_N_BK1( - Transform::MakeB0GridDescriptor_N_K(b0_g_l_k_lengths_vec, b0_g_l_k_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeB1GridDescriptor(const std::array& b1_g_n_l_lengths_vec, - const std::array& b1_g_n_l_strides_vec) - { - return Transform::MakeB1GridDescriptor_BK0_N_BK1( - Transform::MakeB1GridDescriptor_N_K(b1_g_n_l_lengths_vec, b1_g_n_l_strides_vec), - Number{}); - } - - __host__ __device__ static auto - MakeD0GridDescriptor(const std::array& d0_g_m_n_lengths_vec, - const std::array& d0_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(d0_g_m_n_lengths_vec, d0_g_m_n_strides_vec); - } - - __host__ __device__ static auto MakeD0sGridDescriptor( - const std::array, NumD0Tensor>& d0_g_m_n_lengths_vec, - const std::array, NumD0Tensor>& d0_g_m_n_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeD0GridDescriptor(d0_g_m_n_lengths_vec[i], d0_g_m_n_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto MakeD1sGridDescriptor( - const std::array, NumD0Tensor>& d1_g_m_o_lengths_vec, - const std::array, NumD0Tensor>& d1_g_m_o_strides_vec) - { - return generate_tuple( - [&](auto i) { - return MakeE1GridDescriptor(d1_g_m_o_lengths_vec[i], d1_g_m_o_strides_vec[i]); - }, - Number{}); - } - - __host__ __device__ static auto - MakeE1GridDescriptor(const std::array& e1_g_m_n_lengths_vec, - const std::array& e1_g_m_n_strides_vec) - { - return Transform::MakeCGridDescriptor_M_N(e1_g_m_n_lengths_vec, e1_g_m_n_strides_vec); - } - - using AGridDesc = decltype(MakeAGridDescriptor({}, {})); - using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); - using D0sGridDesc = remove_cvref_t; - using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); - using D1sGridDesc = remove_cvref_t; - using E1GridDesc = decltype(MakeE1GridDescriptor({}, {})); - - struct ComputeBasePtrOfStridedBatch - { - ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1) - : BatchStrideA0_(BatchStrideA0), - BatchStrideB0_(BatchStrideB0), - BatchStrideD0s_(BatchStrideD0s), - BatchStrideB1_(BatchStrideB1), - BatchStrideD1s_(BatchStrideD1s), - BatchStrideE1_(BatchStrideE1) - { - } - - __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideA0_); - } - - __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB0_); - } - - template - __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, - Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD0s_[d1_idx]); - } - - __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideB1_); - } - - __host__ __device__ constexpr long_index_t GetE1BasePtr(index_t g_idx) const - { - return g_idx * static_cast(BatchStrideE1_); - } - - template - __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const - { - return g_idx * static_cast(BatchStrideD1s_[d1_idx]); - } - - private: - index_t BatchStrideA0_; - index_t BatchStrideB0_; - std::array BatchStrideD0s_; - index_t BatchStrideB1_; - std::array BatchStrideD1s_; - index_t BatchStrideE1_; - }; + using DeviceGemmGemmCommonBase = + DeviceGemmGemm_Wmma_CShuffleV3_Common; // IsMultiD // GridwiseOp using GridwiseOp = GridwiseBatchedGemmGemm_wmma_cshuffle_v3< @@ -350,12 +179,12 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CDE1ElementwiseOperation, InMemoryDataOperationEnum::Set, // InMemory Data Descriptor - AGridDesc, - B0GridDesc, - D0sGridDesc, - B1GridDesc, - D1sGridDesc, - E1GridDesc, + typename DeviceGemmGemmCommonBase::AGridDesc, + typename DeviceGemmGemmCommonBase::B0GridDesc, + typename DeviceGemmGemmCommonBase::D0sGridDesc, + typename DeviceGemmGemmCommonBase::B1GridDesc, + typename DeviceGemmGemmCommonBase::D1sGridDesc, + typename DeviceGemmGemmCommonBase::E1GridDesc, // Tiling Family MPerBlock, LPerBlock, @@ -402,430 +231,67 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, - Transform::matrix_padder.PadN, + DeviceGemmGemmCommonBase::GridDescriptorCreator::Transform::matrix_padder.PadN, BlkGemmPipeSched, BlkGemmPipelineVer>; - struct RawArg : public BaseArgument + using DeviceGemmGemmCommon = DeviceGemmGemm_Wmma_CShuffleV3_Common_Invoker_Arg< + DeviceOp, + GemmSpec, + ALayout, + B0layout, + D0sLayout, + B1Layout, + D1sLayout, + E1Layout, + BlockSize, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock, + ADataType, + B0DataType, + B1DataType, + AccDataType, + E1DataType, + D0sDataType, + D1sDataType, + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + AK1, + BK1, + L1, + MPerWmma, + LPerWmma, + BlkGemmPipelineVer, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + CDE0BlockTransferSrcScalarPerVector, + CShuffleBlockTransferScalarPerVector_NPerBlock, + true>; // IsMultiD + // Invoker + using Invoker = typename DeviceGemmGemmCommon::Invoker; + + // Argument + using Argument = typename DeviceGemmGemmCommon::Argument; + + static bool IsSupportedArgument(const Argument& arg) { - using arr3 = std::array; - - RawArg(const ADataType* p_a_grid_, - const B0DataType* p_b0_grid_, - std::array p_d0s_grid_, - const B1DataType* p_b1_grid_, - std::array p_d1s_grid_, - E1DataType* p_e1_grid_, - index_t M_, - index_t N_, - index_t K_, - index_t O_, - index_t Batch, - index_t StrideA, - index_t StrideB0, - std::array StrideD0s, - index_t StrideB1, - std::array StrideD1s, - index_t StrideE1, - index_t BatchStrideA, - index_t BatchStrideB0, - std::array BatchStrideD0s, - index_t BatchStrideB1, - std::array BatchStrideD1s, - index_t BatchStrideE1, - AElementwiseOperation a_element_op_, - B0ElementwiseOperation b0_element_op_, - AccElementwiseOperation acc_element_op_, - B1ElementwiseOperation b1_element_op_, - CDE1ElementwiseOperation cde1_element_op_) - : p_a_grid{p_a_grid_}, - p_b0_grid{p_b0_grid_}, - p_d0s_grid{}, - p_b1_grid{p_b1_grid_}, - p_d1s_grid{}, - p_e1_grid{p_e1_grid_}, - M{M_}, - N{N_}, - K{K_}, - O{O_}, - batch_count{Batch}, - a_element_op{a_element_op_}, - b0_element_op{b0_element_op_}, - acc_element_op{acc_element_op_}, - b1_element_op{b1_element_op_}, - cde1_element_op{cde1_element_op_}, - compute_base_ptr_of_batch{BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1} - { - - a_g_m_k_lengths = arr3{batch_count, M, K}; - a_g_m_k_strides = arr3{BatchStrideA, StrideA, 1}; // A layout [batch_count, M, K] - - b0_g_n_k_lengths = arr3{batch_count, N, K}; - b0_g_n_k_strides = arr3{BatchStrideB0, StrideB0, 1}; // B0 layout [batch_count, N, K] - - b1_g_o_n_lengths = arr3{batch_count, O, N}; - b1_g_o_n_strides = - is_same_v - ? arr3{BatchStrideB1, 1, StrideB1} // B1 layout [batch_count, N, O] - : arr3{BatchStrideB1, StrideB1, 1}; // B1 layout [batch_count, O, N] - - e1_g_m_o_lengths = arr3{batch_count, M, O}; - e1_g_m_o_strides = arr3{BatchStrideE1, StrideE1, 1}; // C layout [batch_count, M, O] - - a_grid_desc = MakeAGridDescriptor(a_g_m_k_lengths, a_g_m_k_strides); - b0_grid_desc = MakeB0GridDescriptor(b0_g_n_k_lengths, b0_g_n_k_strides); - b1_grid_desc = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides); - e1_grid_desc_m_n = MakeE1GridDescriptor(e1_g_m_o_lengths, e1_g_m_o_strides); - e1_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e1_grid_desc_m_n); - block_2_etile_map = GridwiseOp::MakeDefaultBlock2ETileMap(e1_grid_desc_m_n, 1, 1); - - static_for<0, NumD0Tensor, 1>{}([&](auto i) { - using D0DataType = remove_cvref_t>; - - // D0s layout [batch_count, M, N] - d0s_g_m_n_lengths[i] = arr3{batch_count, M, N}; - d0s_g_m_n_strides[i] = arr3{BatchStrideD0s[i], StrideD0s[i], 1}; - - // D0 pointer - p_d0s_grid(i) = static_cast(p_d0s_grid_[i]); - - // D0 desc - d0s_grid_desc(i) = MakeD0GridDescriptor(d0s_g_m_n_lengths[i], d0s_g_m_n_strides[i]); - }); - - static_for<0, NumD1Tensor, 1>{}([&](auto i) { - using D1DataType = remove_cvref_t>; - - // D1s layout [batch_count, M, O] - d1s_g_m_o_lengths[i] = arr3{batch_count, M, O}; - d1s_g_m_o_strides[i] = arr3{BatchStrideD1s[i], StrideD1s[i], 1}; - - // D1 pointer - p_d1s_grid(i) = static_cast(p_d1s_grid_[i]); - - // D1 desc - d1s_grid_desc(i) = MakeE1GridDescriptor(d1s_g_m_o_lengths[i], d1s_g_m_o_strides[i]); - }); - - d1s_grid_desc_mblock_mperblock_nblock_nperblock = - GridwiseOp::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(d1s_grid_desc); - } - - // Pointers - const ADataType* p_a_grid; - const B0DataType* p_b0_grid; - typename GridwiseOp::D0sGridPointer p_d0s_grid; - const B1DataType* p_b1_grid; - typename GridwiseOp::D1sGridPointer p_d1s_grid; - E1DataType* p_e1_grid; - - // Raw Problem Size - index_t M; - index_t N; - index_t K; - index_t O; - index_t batch_count; - - arr3 a_g_m_k_lengths; - arr3 a_g_m_k_strides; - arr3 b0_g_n_k_lengths; - arr3 b0_g_n_k_strides; - std::array d0s_g_m_n_lengths; - std::array d0s_g_m_n_strides; - arr3 b1_g_o_n_lengths; - arr3 b1_g_o_n_strides; - std::array d1s_g_m_o_lengths; - std::array d1s_g_m_o_strides; - arr3 e1_g_m_o_lengths; - arr3 e1_g_m_o_strides; - - AElementwiseOperation a_element_op; - B0ElementwiseOperation b0_element_op; - AccElementwiseOperation acc_element_op; - B1ElementwiseOperation b1_element_op; - CDE1ElementwiseOperation cde1_element_op; - - // Grid descriptors and other mem calculators - AGridDesc a_grid_desc; - B0GridDesc b0_grid_desc; - D0sGridDesc d0s_grid_desc; - B1GridDesc b1_grid_desc; - D1sGridDesc d1s_grid_desc; - typename GridwiseOp::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock; - - E1GridDesc e1_grid_desc_m_n; - typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock; - - typename GridwiseOp::DefaultBlock2ETileMap block_2_etile_map; - - ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch; - }; - - // check if DsLayout is supported - template - static constexpr bool CheckDLayout() - { - bool valid = true; - // iterate over DLayout tuple - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DLayout = remove_cvref_t>; - // if RefLayout and DLayout are same, keep valid true, otherwise false - valid = valid && is_same_v; - }); - return valid; + return DeviceGemmGemmCommon::IsSupportedArgument(arg); } - - static bool IsSupportedArgument([[maybe_unused]] const RawArg& arg) - { - // Print lambda with env check and printf() style formmating. - const char* curFunc = __func__; - auto print = [&curFunc](const char* format, ...) -> void { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { -#if defined(__clang__) -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wformat-nonliteral" -#endif - va_list args; - va_start(args, format); - std::vfprintf(stdout, format, args); - va_end(args); -#if defined(__clang__) -#pragma clang diagnostic pop -#endif - std::cout << "In file: " << __FILE__ << ", function: " << curFunc << "\n"; - } - }; - - if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported())) - { - print("DeviceOp: Arch err\n"); - return false; - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - print("DeviceOp: gfx 11 does not support fp8\n"); - return false; - } - } - - if constexpr(!(is_same_v || is_same_v)) - { - print("DeviceOp: Acc0 Type err\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: A layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: B0 layout must be Column\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D0s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v || - is_same_v)) - { - print("DeviceOp: B1 layout must be Column or Row\n"); - return false; - } - - if constexpr(!(CheckDLayout())) - { - print("DeviceOp: All D1s layout must be Row\n"); - return false; - } - - if constexpr(!(is_same_v)) - { - print("DeviceOp: C layout must be Row\n"); - return false; - } - - // Other padding modes have not been tested and do not get checked individually. - if constexpr(GemmSpec != GemmSpecialization::Default && - GemmSpec != GemmSpecialization::MNKOPadding) - { - print("Padding mode must be default or MNKO\n"); - return false; - } - - // Per wmma dimensions not equal to 16 are very untested. - if constexpr(MPerWmma != 16 || LPerWmma != 16 || NPerWmma != 16) - { - print("M, L, N per Wmma must be 16\n"); - return false; - } - - if(!GridwiseOp::CheckValidity(arg.a_grid_desc, - arg.b0_grid_desc, - arg.d0s_grid_desc, - arg.b1_grid_desc, - arg.d1s_grid_desc, - arg.e1_grid_desc_m_n, - arg.block_2_etile_map)) - { - return false; - } - - // Check scalar per vector requirement - const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? arg.K : arg.M; - const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? arg.K : arg.N; - const auto cde0_extent_lowest = arg.N; // D0 tensors forced to be row-major - const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? arg.N : arg.O; - const auto cde1_extent_lowest = arg.O; - - if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && - b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && - cde0_extent_lowest % CDE0BlockTransferSrcScalarPerVector == 0 && - b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && - cde1_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) - { - print("DeviceOp: Data Transfer Vector scalar err\n"); - return false; - } - - // Check vector load/store requirement - const auto a_stride_lowest = - ABlockTransferSrcVectorDim == 2 ? arg.a_g_m_k_strides[2] : arg.a_g_m_k_strides[1]; - const auto b0_stride_lowest = - B0BlockTransferSrcVectorDim == 2 ? arg.b0_g_n_k_strides[2] : arg.b0_g_n_k_strides[1]; - const auto b1_stride_lowest = - B1BlockTransferSrcVectorDim == 2 ? arg.b1_g_o_n_strides[2] : arg.b1_g_o_n_strides[1]; - const auto e1_stride_lowest = arg.e1_g_m_o_strides[2]; - - // NOTE: We don't check D0s/D1s stride, as they are already forced to be row-major - // and the lowest dimension stride is hardcoded to 1 - - if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || - e1_stride_lowest == 1)) - { - print("DeviceOp: Data Vectorize transfer err\n"); - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MNKOPadding)) - { - return false; - } - - return true; - } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { - return IsSupportedArgument(*dynamic_cast(p_arg)); + return DeviceGemmGemmCommon::IsSupportedArgument(*dynamic_cast(p_arg)); } - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::RawArg; - - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - const auto M0 = math::integer_divide_ceil(arg.M, MPerBlock); - const auto N0 = math::integer_divide_ceil(arg.O, NPerBlock); - - const index_t grid_size = arg.batch_count * M0 * N0; - - auto launch_kernel = [&](auto has_main_k_block_loop, auto tail_number) { - constexpr bool has_loop = decltype(has_main_k_block_loop)::value; - constexpr TailNumber tn = tail_number; - - const auto kernel = - kernel_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3; - - return launch_and_time_kernel( - stream_config, kernel, dim3(grid_size), dim3(BlockSize), 0, arg); - }; - - bool HasMainKBlockLoop = GridwiseOp::CalculateHasMainKBlockLoop(arg.K); - TailNumber TailNum = GridwiseOp::CalculateKBlockLoopTailNum(arg.K); - - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V1!\n"); - return 0.0f; - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(HasMainKBlockLoop && TailNum == TailNumber::Full) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Even) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else if(!HasMainKBlockLoop && TailNum == TailNumber::Odd) - { - return launch_kernel(std::integral_constant{}, - std::integral_constant{}); - } - else - { - printf("Invalid HasMainKBlockLoop and TailNum combination for V3!\n"); - return 0.0f; - } - } - else - { - printf("Invalid pipeline version!\n"); - return 0.0f; - } - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - static auto MakeArgument(const ADataType* p_a0, const B0DataType* p_b0, std::array p_d0s, @@ -855,20 +321,20 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation cde1_element_op) { - return RawArg{p_a0, p_b0, - p_d0s, p_b1, - p_d1s, p_e1, - MRaw, NRaw, - KRaw, Gemm1NRaw, - Batch, StrideA0, - StrideB0, StrideD0s, - StrideB1, StrideD1s, - StrideE1, BatchStrideA0, - BatchStrideB0, BatchStrideD0s, - BatchStrideB1, BatchStrideD1s, - BatchStrideE1, a0_element_op, - b0_element_op, cde0_element_op, - b1_element_op, cde1_element_op}; + return Argument{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; } // polymorphic @@ -902,34 +368,34 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3 B1ElementwiseOperation b1_element_op, CDE1ElementwiseOperation c_element_op) override { - return std::make_unique(static_cast(p_a), - static_cast(p_b0), - p_d0s, - static_cast(p_b1), - p_d1s, - static_cast(p_c), - M, - N, - K, - O, - Batch, - StrideA, - StrideB0, - StrideD0s, - StrideB1, - StrideD1s, - StrideE1, - BatchStrideA, - BatchStrideB0, - BatchStrideD0s, - BatchStrideB1, - BatchStrideD1s, - BatchStrideE1, - a_element_op, - b0_element_op, - acc_element_op, - b1_element_op, - c_element_op); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_c), + M, + N, + K, + O, + Batch, + StrideA, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); } static auto MakeInvoker() { return Invoker{}; } From 834642202c0cb39df1b96dacc24d5c3b3d97e62c Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Mon, 26 Jan 2026 20:23:26 +0200 Subject: [PATCH 05/28] Re enable f8 x bf8 tests on compv3 and compv4 (#3605) * Re-enable f8 x bf8 tests on CompV3 as they now pass * On CompV4, fp8 x bf8 tests now pass with K_BlockSize I32 * Add a changelog entry --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- CHANGELOG.md | 1 + test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp | 9 ++------- test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++---- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f17a4d768..c99fc1d065 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. * Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM. diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp index ebe17aadd6..016f7be60d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv3.cpp @@ -13,13 +13,8 @@ class TestCkTileGemmPipelineCompV3 static constexpr bool check_data_type() { using Base = TestCkTileGemmPipeline>; - if constexpr(std::is_same_v && - std::is_same_v) - { - return false; - } - else if constexpr(std::is_same_v && - std::is_same_v) + if constexpr(std::is_same_v && + std::is_same_v) { return false; } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 334e360eb5..4bef581254 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -170,7 +170,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -180,7 +180,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -190,7 +190,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, @@ -200,7 +200,7 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, - std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV4>, + std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4>, std::tuple< Col, Col, Row, BF8, I4, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> From 3900e1e7ceacfa32cb8d1522260ed30befd4dae3 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Mon, 26 Jan 2026 10:29:28 -0800 Subject: [PATCH 06/28] Solve the CTAD regression & add up the Shell file for the docker management in testing (#3634) * Finished the work * Fix the pipeline --- ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 2 +- .../ck_tile/ops/reduce/block/block_reduce.hpp | 4 - .../ops/softmax/block/block_softmax_2d.hpp | 2 +- script/tools/ck-build | 143 +++++++++++++++ script/tools/ck-clean | 113 ++++++++++++ script/tools/ck-exec | 111 ++++++++++++ script/tools/ck-logs | 134 ++++++++++++++ script/tools/ck-shell | 84 +++++++++ script/tools/ck-start | 103 +++++++++++ script/tools/ck-status | 153 ++++++++++++++++ script/tools/ck-stop | 141 +++++++++++++++ script/tools/ck-test | 166 ++++++++++++++++++ 12 files changed, 1150 insertions(+), 6 deletions(-) create mode 100755 script/tools/ck-build create mode 100755 script/tools/ck-clean create mode 100755 script/tools/ck-exec create mode 100755 script/tools/ck-logs create mode 100755 script/tools/ck-shell create mode 100755 script/tools/ck-start create mode 100755 script/tools/ck-status create mode 100755 script/tools/ck-stop create mode 100755 script/tools/ck-test diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index c4ab1d4a78..34d18cb8e1 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -227,7 +227,7 @@ struct MXFlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy sequence<1>>{}); else return make_static_tile_distribution( - tile_distribution_encoding< // + tile_distribution_encoding< sequence, tuple, sequence>, diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 4284e7622f..3f59e2d036 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -392,8 +392,4 @@ struct BlockReduce2D InDataType reduce_init; }; -// deduction guide -template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; - } // namespace ck_tile diff --git a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp index abb95934ff..58e768b319 100644 --- a/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp +++ b/include/ck_tile/ops/softmax/block/block_softmax_2d.hpp @@ -40,7 +40,7 @@ struct BlockSoftmax2D #endif // compute row max - auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; + auto reduce_row_max = BlockReduce2D{x, -numeric::infinity()}; #if _BLOCK_SOFTMAX_USE_UNPACK2 auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{}); #else diff --git a/script/tools/ck-build b/script/tools/ck-build new file mode 100755 index 0000000000..2c0bb24eda --- /dev/null +++ b/script/tools/ck-build @@ -0,0 +1,143 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Build - Build Composable Kernel targets in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Build - Build Composable Kernel targets in Docker + +Usage: ck-build [options] [target...] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + -j Parallel jobs (passed to ninja) + --clean Clean before building + +Arguments: + target Target(s) to build (default: all) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-build # Build all targets + ck-build test_amdgcn_mma # Build specific target + ck-build test_amdgcn_mma test_gemm # Build multiple targets + ck-build --reconfigure # Reconfigure CMake and build all + ck-build --clean test_amdgcn_mma # Clean and build target + ck-build -j 8 test_amdgcn_mma # Build with 8 parallel jobs + +EOF +} + +# Parse arguments +targets=() +reconfigure=false +clean=false +parallel_jobs="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --clean) + clean=true + shift + ;; + -j) + parallel_jobs="-j $2" + shift 2 + ;; + *) + targets+=("$1") + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Clean if requested +if [ "$clean" = true ]; then + echo "Cleaning build directory..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "" +fi + +# Build targets +if [ ${#targets[@]} -eq 0 ]; then + echo "Building all configured targets..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} 2>&1 + " +else + echo "Building targets: ${targets[*]}" + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${parallel_jobs} ${targets[*]} 2>&1 + " +fi + +echo "" +echo "Build complete ✓" diff --git a/script/tools/ck-clean b/script/tools/ck-clean new file mode 100755 index 0000000000..4b422f81f4 --- /dev/null +++ b/script/tools/ck-clean @@ -0,0 +1,113 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Clean - Clean build artifacts in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Clean - Clean build artifacts in Docker container + +Usage: ck-clean [options] + +Options: + -h, --help Show this help message + --name Specify container name + --all Remove entire build directory + -f, --force Force without confirmation + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-clean # Clean build artifacts (ninja clean) + ck-clean --all # Remove entire build directory + ck-clean --force --all # Remove build directory without confirmation + +EOF +} + +# Parse arguments +remove_all=false +force=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + remove_all=true + shift + ;; + -f|--force) + force=true + shift + ;; + *) + echo "Unknown option: $1" + show_help + exit 1 + ;; + esac +done + +# Check if container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running" + echo "Start with: ck-start" + exit 1 +fi + +# Check if build directory exists +if ! docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + echo "Build directory does not exist" + exit 0 +fi + +if [ "$remove_all" = true ]; then + # Remove entire build directory + if [ "$force" = false ]; then + read -p "Remove entire build directory? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "Removing build directory..." + docker exec "${CONTAINER_NAME}" bash -c "rm -rf /workspace/build" + echo "Build directory removed ✓" +else + # Clean with ninja + if ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Build not configured (build.ninja not found)" + echo "Use --all to remove build directory" + exit 1 + fi + + echo "Cleaning build artifacts..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja clean + " + echo "Build artifacts cleaned ✓" +fi diff --git a/script/tools/ck-exec b/script/tools/ck-exec new file mode 100755 index 0000000000..dfc7655774 --- /dev/null +++ b/script/tools/ck-exec @@ -0,0 +1,111 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Exec - Execute arbitrary commands in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Exec - Execute arbitrary commands in Docker container + +Usage: ck-exec [options] [args...] + +Options: + -h, --help Show this help message + --name Specify container name + -w Working directory (default: /workspace) + -i, --interactive Interactive mode (allocate TTY) + +Arguments: + command Command to execute (required) + args Arguments to the command + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-exec rocm-smi # Run rocm-smi + ck-exec rocminfo # Run rocminfo + ck-exec ls -la build/bin # List build binaries + ck-exec -w /workspace/build ninja -t commands # Run ninja commands + ck-exec --interactive python3 # Interactive Python session + +Common Commands: + ck-exec rocm-smi # Check GPU status + ck-exec rocminfo \| grep gfx # Check GPU architecture + ck-exec hipcc --version # Check HIP compiler version + ck-exec cmake --version # Check CMake version + ck-exec ninja -C build -t targets # List all build targets + +EOF +} + +# Parse arguments +workdir="/workspace" +interactive=false +command_args=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -w) + workdir="$2" + shift 2 + ;; + -i|--interactive) + interactive=true + shift + ;; + *) + command_args+=("$1") + shift + ;; + esac +done + +# Validate command +if [ ${#command_args[@]} -eq 0 ]; then + echo "Error: command required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Build command string +cmd_string="" +for arg in "${command_args[@]}"; do + cmd_string="${cmd_string} $(printf '%q' "$arg")" +done + +# Execute command +if [ "$interactive" = true ]; then + docker exec -it -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +else + docker exec -w "${workdir}" "${CONTAINER_NAME}" bash -c "${cmd_string}" +fi diff --git a/script/tools/ck-logs b/script/tools/ck-logs new file mode 100755 index 0000000000..cfad23b3b5 --- /dev/null +++ b/script/tools/ck-logs @@ -0,0 +1,134 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Logs - View container logs and build output + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Logs - View container logs and build output + +Usage: ck-logs [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -f, --follow Follow log output + -n, --tail Show last N lines (default: 100) + --cmake Show CMake configuration log + --build Show last build log + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-logs # Show last 100 lines of container logs + ck-logs -f # Follow container logs + ck-logs -n 500 # Show last 500 lines + ck-logs --cmake # Show CMake configuration + ck-logs --build # Show build log + +EOF +} + +# Parse arguments +follow=false +tail_lines=100 +show_cmake=false +show_build=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -f|--follow) + follow=true + shift + ;; + -n|--tail) + tail_lines="$2" + shift 2 + ;; + --cmake) + show_cmake=true + shift + ;; + --build) + show_build=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Check if container exists +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 1 +fi + +# Show CMake log +if [ "$show_cmake" = true ]; then + echo "CMake Configuration Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/CMakeCache.txt 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build + echo 'GPU_TARGETS:' \$(grep 'GPU_TARGETS:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_BUILD_TYPE:' \$(grep 'CMAKE_BUILD_TYPE:' CMakeCache.txt | cut -d'=' -f2) + echo 'CMAKE_CXX_COMPILER:' \$(grep 'CMAKE_CXX_COMPILER:' CMakeCache.txt | cut -d'=' -f2) + echo 'BUILD_TESTING:' \$(grep 'BUILD_TESTING:' CMakeCache.txt | cut -d'=' -f2) + " + else + echo "CMake not configured yet" + fi + exit 0 +fi + +# Show build log (last build output) +if [ "$show_build" = true ]; then + echo "Last Build Log:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/.ninja_log 2>/dev/null; then + docker exec "${CONTAINER_NAME}" bash -c "tail -50 /workspace/build/.ninja_log" + else + echo "No build log found" + fi + exit 0 +fi + +# Show container logs +echo "Container Logs (${CONTAINER_NAME}):" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if [ "$follow" = true ]; then + docker logs -f "${CONTAINER_NAME}" +else + docker logs --tail "${tail_lines}" "${CONTAINER_NAME}" +fi diff --git a/script/tools/ck-shell b/script/tools/ck-shell new file mode 100755 index 0000000000..785c9f4d68 --- /dev/null +++ b/script/tools/ck-shell @@ -0,0 +1,84 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Shell - Open interactive shell in Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Shell - Open interactive shell in Docker container + +Usage: ck-shell [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + -c Execute command instead of interactive shell + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-shell # Open interactive shell + ck-shell my_container # Open shell in specific container + ck-shell -c "rocm-smi" # Execute single command + ck-shell -c "cd build && ls bin" # Execute command in build directory + +EOF +} + +# Parse arguments +command="" + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + -c) + command="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Execute command or open shell +if [ -n "$command" ]; then + echo "Executing: ${command}" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "${command}" +else + echo "Opening shell in '${CONTAINER_NAME}' (type 'exit' to leave)..." + docker exec -it "${CONTAINER_NAME}" bash +fi diff --git a/script/tools/ck-start b/script/tools/ck-start new file mode 100755 index 0000000000..f15477492a --- /dev/null +++ b/script/tools/ck-start @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Start - Start Docker container for Composable Kernel testing + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Start - Start Docker container for Composable Kernel testing + +Usage: ck-start [options] [container_name] + +Options: + -h, --help Show this help message + --image Specify Docker image (overrides CK_DOCKER_IMAGE) + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + CK_DOCKER_IMAGE - Override Docker image (default: rocm/composable_kernel:ck_ub24.04_rocm7.0.1) + +Examples: + ck-start # Start container with default name + ck-start my_ck_container # Start container with custom name + ck-start --image rocm/composable_kernel:latest + +EOF +} + +# Parse arguments +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --image) + export CK_DOCKER_IMAGE="$2" + shift 2 + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Get Docker image +DOCKER_IMAGE=$(get_docker_image) + +# Check if container exists and is running +if container_exists "${CONTAINER_NAME}"; then + if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is already running" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + else + echo "Starting existing container '${CONTAINER_NAME}'..." + docker start "${CONTAINER_NAME}" + echo "Container started" + docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + exit 0 + fi +fi + +# Create new container +echo "Creating new Docker container '${CONTAINER_NAME}'..." +echo "Docker image: ${DOCKER_IMAGE}" +echo "Project root: ${PROJECT_ROOT}" +echo "" + +docker run -d \ + --name "${CONTAINER_NAME}" \ + --device=/dev/kfd --device=/dev/dri \ + --security-opt seccomp=unconfined \ + --group-add video \ + -v "${PROJECT_ROOT}":/workspace \ + -w /workspace \ + "${DOCKER_IMAGE}" \ + tail -f /dev/null + +echo "" +echo "Container '${CONTAINER_NAME}' started successfully" +docker exec "${CONTAINER_NAME}" bash -c "echo 'Working directory:' && pwd" + +# Show GPU info +echo "" +echo "GPU Information:" +docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -5 || echo 'No GPU detected'" diff --git a/script/tools/ck-status b/script/tools/ck-status new file mode 100755 index 0000000000..fea9de8c36 --- /dev/null +++ b/script/tools/ck-status @@ -0,0 +1,153 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Status - Check container status and information + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Status - Check container status and information + +Usage: ck-status [options] [container_name] + +Options: + -h, --help Show this help message + --name Specify container name + --all Show all CK containers + -v, --verbose Show detailed information + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-status # Check default container status + ck-status my_container # Check specific container + ck-status --all # Show all CK containers + ck-status -v # Show detailed information + +EOF +} + +# Parse arguments +show_all=false +verbose=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --all) + show_all=true + shift + ;; + -v|--verbose) + verbose=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +DOCKER_IMAGE=$(get_docker_image) + +# Show all containers +if [ "$show_all" = true ]; then + echo "Composable Kernel Docker Containers:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format "table {{.Names}}\t{{.Status}}\t{{.CreatedAt}}" 2>/dev/null || echo "") + + if [ -z "$containers" ] || [ "$containers" = "NAMES STATUS CREATED AT" ]; then + echo "No CK containers found for user '${username}'" + else + echo "$containers" + fi + exit 0 +fi + +# Check specific container status +echo "Container: ${CONTAINER_NAME}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +if container_is_running "${CONTAINER_NAME}"; then + echo "Status: RUNNING ✓" + echo "" + docker ps --filter "name=^${CONTAINER_NAME}$" --format "table {{.Names}}\t{{.Status}}\t{{.Image}}" + + if [ "$verbose" = true ]; then + echo "" + echo "Container Details:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker inspect "${CONTAINER_NAME}" --format ' +Image: {{.Config.Image}} +Created: {{.Created}} +Platform: {{.Platform}} +Mounts: {{range .Mounts}} + - {{.Source}} -> {{.Destination}}{{end}} +' + fi + + echo "" + echo "GPU Information:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + docker exec "${CONTAINER_NAME}" bash -c "rocm-smi --showproductname 2>/dev/null | head -10 || echo 'No GPU detected'" + + if [ "$verbose" = true ]; then + echo "" + echo "GPU Target:" + gpu_target=$(detect_gpu_target "${CONTAINER_NAME}") + echo " ${gpu_target}" + + echo "" + echo "Build Status:" + echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + if docker exec "${CONTAINER_NAME}" test -d /workspace/build 2>/dev/null; then + if docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo " CMake configured ✓" + echo " Build directory: /workspace/build" + + # Count built test binaries + bin_count=$(docker exec "${CONTAINER_NAME}" bash -c "ls -1 /workspace/build/bin 2>/dev/null | wc -l" || echo "0") + echo " Test binaries: ${bin_count}" + else + echo " CMake not configured" + fi + else + echo " Build directory not found" + fi + fi + +elif container_exists "${CONTAINER_NAME}"; then + echo "Status: STOPPED" + echo "" + echo "Start with: ck-start" +else + echo "Status: DOES NOT EXIST" + echo "" + echo "Create with: ck-start" +fi diff --git a/script/tools/ck-stop b/script/tools/ck-stop new file mode 100755 index 0000000000..b793f47408 --- /dev/null +++ b/script/tools/ck-stop @@ -0,0 +1,141 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Stop - Stop and remove Docker container + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Stop - Stop and remove Docker container + +Usage: ck-stop [options] [container_name] + +Options: + -h, --help Show this help message + -f, --force Force stop without confirmation + --all Stop all CK containers for this user + +Arguments: + container_name Optional container name (default: ck__) + +Environment: + CK_CONTAINER_NAME - Override default container name + +Examples: + ck-stop # Stop default container + ck-stop my_ck_container # Stop specific container + ck-stop --all # Stop all user's CK containers + ck-stop --force # Stop without confirmation + +EOF +} + +# Parse arguments +force=false +stop_all=false + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + -f|--force) + force=true + shift + ;; + --all) + stop_all=true + shift + ;; + *) + CONTAINER_NAME="$1" + shift + ;; + esac +done + +# Function to stop a single container +stop_container() { + local name="$1" + + if ! container_exists "${name}"; then + echo "Container '${name}' does not exist" + return 1 + fi + + echo "Stopping and removing container '${name}'..." + docker stop "${name}" 2>/dev/null || true + docker rm "${name}" 2>/dev/null || true + echo "Container '${name}' stopped and removed" +} + +# Stop all user containers +if [ "$stop_all" = true ]; then + username=$(get_username) + containers=$(docker ps -a --filter "name=ck_${username}_" --format '{{.Names}}') + + if [ -z "$containers" ]; then + echo "No CK containers found for user '${username}'" + exit 0 + fi + + echo "Found CK containers for user '${username}':" + echo "$containers" + echo "" + + if [ "$force" = false ]; then + read -p "Stop and remove all these containers? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi + fi + + echo "" + while IFS= read -r container; do + stop_container "$container" + done <<< "$containers" + + echo "" + echo "All containers stopped and removed" + exit 0 +fi + +# Stop single container +if ! container_exists "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' does not exist" + exit 0 +fi + +# Show container info +if container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' is currently running" +else + echo "Container '${CONTAINER_NAME}' exists but is stopped" +fi + +# Confirm if not forced +if [ "$force" = false ]; then + read -p "Stop and remove container '${CONTAINER_NAME}'? (y/N) " -n 1 -r + echo "" + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + echo "Cancelled" + exit 0 + fi +fi + +stop_container "${CONTAINER_NAME}" diff --git a/script/tools/ck-test b/script/tools/ck-test new file mode 100755 index 0000000000..712f904596 --- /dev/null +++ b/script/tools/ck-test @@ -0,0 +1,166 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# CK Test - Build and test Composable Kernel in Docker + +set -e +set -o pipefail + +# Find script directory and load common utilities +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/common.sh" + +# Initialize configuration +PROJECT_ROOT=$(get_project_root "${SCRIPT_DIR}") +CONTAINER_NAME=$(get_container_name "${PROJECT_ROOT}") + +# Help message +show_help() { + cat << EOF +CK Test - Build and test Composable Kernel in Docker + +Usage: ck-test [options] [test_options] + +Options: + -h, --help Show this help message + --name Specify container name + --reconfigure Reconfigure CMake before building + --no-build Skip building, run test directly + +Arguments: + test_name Name of test executable (required) + test_options Additional options passed to test (e.g., --gtest_filter=*) + +Environment: + CK_CONTAINER_NAME - Override default container name + GPU_TARGET - Override GPU target detection (e.g., gfx950, gfx942) + +Examples: + ck-test test_amdgcn_mma + ck-test test_amdgcn_mma --gtest_filter=*Fp16* + ck-test --name my_container test_amdgcn_mma + ck-test --reconfigure test_amdgcn_mma + +EOF +} + +# Parse arguments +test_name="" +reconfigure=false +no_build=false +test_options=() + +while [[ $# -gt 0 ]]; do + case $1 in + -h|--help) + show_help + exit 0 + ;; + --name) + CONTAINER_NAME="$2" + shift 2 + ;; + --reconfigure) + reconfigure=true + shift + ;; + --no-build) + no_build=true + shift + ;; + --gtest_*|--help) + test_options+=("$1") + shift + ;; + *) + if [ -z "$test_name" ]; then + test_name="$1" + else + test_options+=("$1") + fi + shift + ;; + esac +done + +# Validate test name +if [ -z "$test_name" ]; then + echo "Error: test_name required" + echo "" + show_help + exit 1 +fi + +# Ensure container is running +if ! container_is_running "${CONTAINER_NAME}"; then + echo "Container '${CONTAINER_NAME}' not running. Starting..." + "${SCRIPT_DIR}/ck-start" "${CONTAINER_NAME}" + echo "" +fi + +# Configure CMake if needed or requested +if [ "$reconfigure" = true ] || ! docker exec "${CONTAINER_NAME}" test -f /workspace/build/build.ninja 2>/dev/null; then + echo "Detecting GPU target..." + GPU_TARGET_DETECTED=$(detect_gpu_target "${CONTAINER_NAME}") + + if [ "$reconfigure" = true ]; then + echo "Reconfiguring CMake from scratch for GPU target: ${GPU_TARGET_DETECTED}" + else + echo "Configuring build with CMake for GPU target: ${GPU_TARGET_DETECTED}" + fi + + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace || exit 1 + rm -rf /workspace/build + mkdir /workspace/build + cd /workspace/build || exit 1 + cmake .. -GNinja \ + -DGPU_TARGETS=${GPU_TARGET_DETECTED} \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ + -DBUILD_TESTING=ON 2>&1 | tail -30 + " + echo "" +fi + +# Build test if needed (unless --no-build is specified) +if [ "$no_build" = false ]; then + if ! docker exec "${CONTAINER_NAME}" test -f "/workspace/build/bin/${test_name}" 2>/dev/null; then + echo "Building ${test_name}..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + else + echo "Test executable found, rebuilding to ensure latest version..." + docker exec "${CONTAINER_NAME}" bash -c " + cd /workspace/build || exit 1 + ninja ${test_name} 2>&1 + " + echo "" + fi +fi + +# Run test +echo "Running: ${test_name} ${test_options[*]}" +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" + +# Build the command with proper quoting +cmd="cd /workspace/build && ./bin/${test_name}" +for opt in "${test_options[@]}"; do + cmd="${cmd} $(printf '%q' "$opt")" +done + +docker exec "${CONTAINER_NAME}" bash -c "${cmd}" +exit_code=$? + +echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" +if [ $exit_code -eq 0 ]; then + echo "Test completed successfully" +else + echo "Test failed with exit code: ${exit_code}" +fi + +exit $exit_code From b8751e505d04cbb866bca769d408e9da8cb64c42 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 27 Jan 2026 00:57:42 +0530 Subject: [PATCH 07/28] feat: Add Interwave scheduler for aquant memory pipeline (#3540) * WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipline in aquant * test: add unit test for aquant memory pipeline * WIP: host level interwave pipeline compiles * WIP: interwave implementation computes correct GEMM result when no aquant * WIP: quantization works for subset of problem shapes * WIP: quantization works for subset of problem shapes * WIP: interwave memory pipeline passes local test * feat: Add interwave pipeline implementation for memory pipline in aquant * fix: compilation error on gfx950 * chore: remove debug statements from the code * test: resolve merge conflict * test: remove non rcr unit tests from test suite --- .../gemm_aquant_quantgrouped.cpp | 2 +- .../38_block_scale_gemm/gemm_utils.hpp | 23 ++ .../run_gemm_quant_example.inc | 180 ++++++++++++- .../block_universal_gemm_as_aquant_bs_cr.hpp | 223 +++++++++++++++- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 2 +- test/ck_tile/gemm_block_scale/CMakeLists.txt | 30 ++- ...gemm_quant_aquant_mem_decode_interwave.cpp | 41 +++ ...gemm_quant_aquant_mem_decode_intrawave.cpp | 41 +++ ...emm_quant_aquant_mem_prefill_interwave.cpp | 41 +++ .../test_gemm_quant_aquant_prefill.cpp | 6 +- .../test_gemm_quant_fixtures.hpp | 249 ++++++++++++++++++ 11 files changed, 829 insertions(+), 9 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index ad1a4e0d10..e037be5a18 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantDecode; +using GemmConfig = GemmConfigQuantDecodeInterwave; // GemmConfigQuantPrefill is also supported for aquant grouped quantization // template diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index a95ca4862c..37117eaa0f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -93,6 +93,27 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigQuantDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -229,6 +250,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = ck_tile::get_k_warp_tile(); + + // static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 912527c929..ed1709a9ae 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -650,7 +650,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else { ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(1.0f)}(*aq_tensor_ptr); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) @@ -659,6 +659,184 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } + else if(init_method == 3) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + ck_tile::FillConstant{static_cast(0x38)}(a_m_k); + ck_tile::FillConstant{static_cast(0x22)}(b_k_n); + ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + else + { + ck_tile::FillConstant{static_cast(0x22)}(a_m_k); + ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0x38)}(b_k_n); + + if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + { + ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); + } + } + } + else if(init_method == 4) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + } + ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } + else if(init_method == 5) + { + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); + } + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + } + else + { + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); + } + // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) + for(ck_tile::index_t row = 0; + row < static_cast(aq_tensor_ptr->get_length(0)); + ++row) + { + for(ck_tile::index_t col = 0; + col < static_cast(aq_tensor_ptr->get_length(1)); + ++col) + { + (*aq_tensor_ptr)(row, col) = static_cast(col + 1); + } + } + // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; + ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); + } + else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) + { + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + a_m_k); + ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( + b_k_n); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); + } + } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b52..9d19e902e5 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -274,7 +274,9 @@ struct AQuantBlockUniversalGemmAsBsCr static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { CWarpTensor c_warp_tensor; + // for every column in AQ static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // for every warp corresponding to a quantization scale static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; @@ -322,6 +324,214 @@ struct AQuantBlockUniversalGemmAsBsCr } }; + template + struct BlockGemmImpl + { + static constexpr index_t KPerThread = GemmTraits::KPerThread; + static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters; + + static constexpr index_t KPerInnerLoop = + ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; + + static constexpr auto ALdsTileDistr = + make_static_tile_distribution(MakeABlockDistributionEncode()); + static constexpr auto BLdsTileDistr = + make_static_tile_distribution(MakeBBlockDistributionEncode()); + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_warp_tile_; + BLdsTile b_warp_tile_; + + template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant = {}, + bool_constant = {}) + { + constexpr auto a_lds_load_distr = [&]() { + if constexpr(ALoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeABlockDistributionEncode()), + ADataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeABlockDistributionEncode()); + }(); + constexpr auto b_lds_load_distr = [&]() { + if constexpr(BLoadTranspose) + return make_static_tile_distribution(typename InputTileDistributionTraits< + decltype(MakeBBlockDistributionEncode()), + BDataType>::TransposedDstrEncode{}); + else + return make_static_tile_distribution(MakeBBlockDistributionEncode()); + }(); + constexpr auto a_lds_shape = []() { + if constexpr(ALoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto b_lds_shape = []() { + if constexpr(BLoadTranspose) + return make_tuple(number{}, number{}); + else + return make_tuple(number{}, number{}); + }(); + constexpr auto k_idx_offset = KIdx * KPerInnerLoop; + constexpr auto a_offset = + ALoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + constexpr auto b_offset = + BLoadTranspose ? multi_index<2>{k_idx_offset, 0} : multi_index<2>{0, k_idx_offset}; + + auto a_lds_gemm_window = make_tile_window( + a_block_window.get_bottom_tensor_view(), a_lds_shape, a_offset, a_lds_load_distr); + auto b_lds_gemm_window = make_tile_window( + b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); + + load_int4_tile( + a_warp_tile_, a_lds_gemm_window); + load_int4_tile( + b_warp_tile_, b_lds_gemm_window); + } + + // C += A * B with quantization support + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + AQBlockTensor& aq_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + static_assert(std::is_same_v, + "The CDataType as defined in traits should be the same as corresponding " + "C block tensor data type!"); + constexpr auto warp_size = get_warp_size(); + + // Track which KRepeat chunk is currently loaded + index_t current_k_repeat_loaded = -1; + + // Restructured loop: M → N → QScale → KIterPerQScale + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Iterate over quantization groups + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + CWarpTensor c_warp_tensor; + + // Accumulate K iterations for this quantization group + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + // Map quantization indices to global K iteration + constexpr auto kIterGlobal = + kQScale * Traits::KIterPerQScale + kIterInQScale; + + // Map to KRepeat chunk and KInnerLoopIter offset + constexpr auto kRepeatIdx = kIterGlobal / KInnerLoopIter; + constexpr auto kInnerIdx = kIterGlobal % KInnerLoopIter; + + // Prefetch new chunk if needed + if constexpr(kInnerIdx == 0) + { + if(current_k_repeat_loaded != kRepeatIdx) + { + LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + __builtin_amdgcn_sched_barrier(0); + + if constexpr(kRepeatIdx != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + + current_k_repeat_loaded = kRepeatIdx; + } + } + + // Load A warp tensor + AWarpTensor a_warp_tensor; + a_warp_tensor.get_thread_buffer() = + a_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + a_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, a_warp_y_lengths)); + + // Load B warp tensor + BWarpTensor b_warp_tensor; + b_warp_tensor.get_thread_buffer() = + b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, + b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Synchronization barrier at the end of last iteration + if constexpr(kQScale == Traits::QScalesPerBlockRow - 1 && + kIterInQScale == Traits::KIterPerQScale - 1 && + mIter.value == MIterPerWarp - 1 && + nIter.value == NIterPerWarp - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // Accumulate: first iteration initializes, rest accumulate + if constexpr(kIterInQScale == 0) + { + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Set priority for scheduling + if constexpr(kInnerIdx == 0 && mIter.value == 0 && nIter.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // Apply quantization scale after accumulating all K iterations for this + // group + constexpr auto tbuf_offset = + number{}, + c_warp_y_index_zeros)) / + CBlockTensor::PackedSize>{}; + + AQPickerCommon aq_picker( + aq_block_tensor); + + static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}( + [&](auto c_row) { + float scale_reg_f = aq_picker.template pick(); + c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] += + (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + }; + public: CK_TILE_DEVICE static constexpr auto MakeCBlockTile() { @@ -329,7 +539,8 @@ struct AQuantBlockUniversalGemmAsBsCr MakeCBlockTile(); } - template @@ -338,7 +549,15 @@ struct AQuantBlockUniversalGemmAsBsCr bool_constant a_load_tr = {}, bool_constant b_load_tr = {}) { - block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + if constexpr(Scheduler == GemmPipelineScheduler::Interwave) + { + block_gemm_impl_.template LocalPrefetch( + a_block_window, b_block_window, a_load_tr, b_load_tr); + } + else + { + block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window, a_load_tr, b_load_tr); + } } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 650cd947f7..b87c12c14a 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -499,7 +499,7 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return PipelineImpl{} .template operator()( a_dram_block_window_tmp, - [](const OverrideADataType& a) { return a; }, + [](const BDataType& a) { return a; }, b_dram_block_window_tmp, [](const BDataType& b) { return b; }, aq_dram_block_window_tmp, diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 5749a8d3b2..30c4eb11f9 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -11,7 +11,24 @@ list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") # Typed Test Suite for GEMM Quantization - split into multiple files to reduce compile time - # AQuant tests - split into 6 files + # AQuant tests - split into 10 files + + # AQuant Memory Pipeline tests + add_gtest_executable(test_tile_gemm_quant_aquant_mem_prefill_interwave + test_gemm_quant_aquant_mem_prefill_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_prefill_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_intrawave + test_gemm_quant_aquant_mem_decode_intrawave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_intrawave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_aquant_mem_decode_interwave + test_gemm_quant_aquant_mem_decode_interwave.cpp + ) + target_compile_options(test_tile_gemm_quant_aquant_mem_decode_interwave PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_aquant_base_rcr test_gemm_quant_aquant_base_rcr.cpp ) @@ -150,10 +167,21 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_tensor PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # Target to build only AQuant memory pipeline tests + add_custom_target(test_tile_gemm_aquant_mem_all) + add_dependencies(test_tile_gemm_aquant_mem_all + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave + ) + # Umbrella target to build all gemm quant tests add_custom_target(test_tile_gemm_quant_all) add_dependencies(test_tile_gemm_quant_all # AQuant tests + test_tile_gemm_quant_aquant_mem_prefill_interwave + test_tile_gemm_quant_aquant_mem_decode_intrawave + test_tile_gemm_quant_aquant_mem_decode_interwave test_tile_gemm_quant_aquant_base_rcr test_tile_gemm_quant_aquant_base_rrr_crr test_tile_gemm_quant_aquant_base_ccr diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp new file mode 100644 index 0000000000..a7ab4120a1 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeInterwaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp new file mode 100644 index 0000000000..483138d711 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_decode_intrawave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Decode Intrawave Configuration +// Tuple format: +// clang-format off +using AQuantMemDecodeIntrawaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Decode Intrawave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemDecodeIntrawaveTest) +{ + this->run_test_with_validation(16, 64, 512); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp new file mode 100644 index 0000000000..7e851d9bd3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_mem_prefill_interwave.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Mem Prefill Interwave Configuration +// Tuple format: +// clang-format off +using AQuantMemPrefillInterwaveTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Mem Prefill Interwave +TYPED_TEST_SUITE(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuantMem, AQuantMemPrefillInterwaveTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp index 133c11860a..911af678df 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -25,9 +25,9 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using AQuantPrefillTypes = ::testing::Types< // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 79c86935ef..9652dd449d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -69,6 +69,38 @@ struct GemmConfigPrefill : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefillIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigPrefillInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigDecodeIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigDecodeInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + struct GemmConfigMxFp4 : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; @@ -374,6 +406,223 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase +class TestCkTileGemmAQuantMem + : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::AQLayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + using typename Base::QuantGroupSize; + static constexpr auto QuantType = Base::QuantType; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + // AQuant-specific data generation + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = + ck_tile::get_default_stride(M, K, 0, this->is_row_major(ALayout{})); + const ck_tile::index_t stride_B = + ck_tile::get_default_stride(K, N, 0, this->is_row_major(BLayout{})); + const ck_tile::index_t stride_C = + ck_tile::get_default_stride(M, N, 0, this->is_row_major(CLayout{})); + // AQuant uses grouped quantization for A matrix + const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize::kK); + // AQLayout is parameterized in the test tuple (can be RowMajor or ColumnMajor for AQuant) + const ck_tile::index_t stride_AQ = + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(AQLayout{})); + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + // AQLayout is independently specified for each test case + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(AQLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + // Initialize data with random values + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f}(a_m_k); + } + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f}(aq_m_aqk); + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + // Copy to device + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor temp = a_m_k; + ck_tile::permute_vectors_i4x4_b(temp); + a_m_k_dev_buf.ToDevice(temp.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } + // aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + if constexpr(Base::GemmConfig::PreshuffleQuant) + { + ck_tile::HostTensor aq_shuffle_host = + ck_tile::shuffle_aq(&aq_m_aqk, Base::GemmConfig::K_Tile / QuantGroupSize::kK); + aq_m_aqk_dev_buf.ToDevice(aq_shuffle_host.data()); + } + else + { + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales) + nullptr, // bq_ptr (not used for AQuant) + 1, // k_batch + M, + N, + K, // M, N, K + AQK, // QK_A + 0, // QK_B (not used for AQuant) + stride_A, + stride_B, + stride_C, + stride_AQ, + 0 // strides + }; + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + // Run reference AQuant implementation + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N + << ", K=" << K; + if(!pass) + { + std::cout << "AQuantGrouped - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // AQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = CodegenGemmTraits::TransposeC; + using PipelineProblem = ck_tile::GemmAQuantPipelineProblem; + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrMem; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c>>; + using Kernel = ck_tile::QuantGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for AQuant kernel"); + } + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + // BQuant-specific test fixture template class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase> From 8942a19d5efafa151e0f894599bc625117d7aa76 Mon Sep 17 00:00:00 2001 From: yinglu Date: Tue, 27 Jan 2026 03:38:45 +0800 Subject: [PATCH 08/28] ck: add CK_USE_GFX950 macro (#3636) --- CMakeLists.txt | 5 +++++ include/ck/config.h.in | 7 ------- .../device_grouped_conv_bwd_data_xdl_instance.hpp | 2 +- .../device_grouped_conv_fwd_xdl_merged_groups_instance.hpp | 2 +- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f1bdf8689..356491d9c1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -259,6 +259,11 @@ if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL) + message(STATUS "Enabling XDL FP8 gemms on gfx950") + add_definitions(-DCK_USE_GFX950) + set(CK_USE_GFX950 "ON") +endif() # new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA set(CK_TILE_USE_WMMA 0) diff --git a/include/ck/config.h.in b/include/ck/config.h.in index f5421e7d5e..306a6c2ff1 100644 --- a/include/ck/config.h.in +++ b/include/ck/config.h.in @@ -55,9 +55,6 @@ #ifndef CK_ENABLE_FP32 #define CK_ENABLE_FP32 "ON" #endif -#ifndef CK_ENABLE_TF32 -#define CK_ENABLE_TF32 "ON" -#endif #ifndef CK_ENABLE_FP64 #define CK_ENABLE_FP64 "ON" #endif @@ -88,10 +85,6 @@ #cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ #endif -#ifndef CK_ENABLE_TF32 -#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@ -#endif - #ifndef CK_ENABLE_FP64 #cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ #endif diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 745f8cbd32..970bcb0439 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -376,7 +376,7 @@ using device_grouped_conv_bwd_data_xdl_f32_optimized_loads_instances = // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp index 18abcb1613..3b7ce0df3a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp @@ -147,7 +147,7 @@ using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple< // clang-format on >; -#if defined(__gfx950__) +#if defined(CK_USE_GFX950) constexpr auto _k_per_block = 32; #else constexpr auto _k_per_block = 16; From bd5fec81afdb6df7f4637128a3ba86dbfd6bcca1 Mon Sep 17 00:00:00 2001 From: Thrupti Raj Lakshmana Gowda Date: Mon, 26 Jan 2026 13:56:06 -0600 Subject: [PATCH 09/28] Removing [4,64,16] warp tile from Tile Engine (#3643) --- tile_engine/ops/gemm/gemm_validation_utils.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_validation_utils.py b/tile_engine/ops/gemm/gemm_validation_utils.py index cae6123307..1af45f8e90 100644 --- a/tile_engine/ops/gemm/gemm_validation_utils.py +++ b/tile_engine/ops/gemm/gemm_validation_utils.py @@ -128,7 +128,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -136,7 +135,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]], @@ -148,7 +146,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -156,7 +153,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]], @@ -169,7 +165,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "bf16_bf16_bf16": [ @@ -177,7 +172,6 @@ GEMM_WARP_TILE_SUPPORTED_COMBINATIONS = { [16, 16, 16], [32, 32, 16], [16, 16, 32], - [4, 64, 16], [64, 4, 16], ], "fp8_fp8_fp16": [ From 2e49b6b2f79d5ab0fe2fca79812affd44de94db7 Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Mon, 26 Jan 2026 21:57:09 +0100 Subject: [PATCH 10/28] Padding support for wave transfer (#3537) * Add padding support with transpose Also move check before writing storing is_src_valid during reading * Add/modify instances to use wave transfer for gemm universal Condition is changed so now the vectorsize of vmem reading and lds writing must be equal to 8 in order to use the wave transfer * Fix clang format * Modify example * Fix bwd data * Add restriction for wave transfer with padding and transpose Add test case which shows this limitation * Fix validity checks 8 bit types * Add validity check gemm_bias_add_reduce * Add validity check grouped gemm tile loop * Fix validity checks new flavours * Minor fixes * Fix clang format --- example/01_gemm/gemm_wmma_fp16_v3.cpp | 10 +-- ...ead_group_tensor_slice_transfer_global.hpp | 69 +++++++++++++--- ...ontraction_multiple_d_wmma_cshuffle_v3.hpp | 20 +++++ ...tched_gemm_multiple_d_wmma_cshuffle_v3.hpp | 20 +++++ ...e_batched_gemm_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ ...e_batched_gemm_wmma_cshuffle_v3_common.hpp | 20 +++++ ..._gemm_bias_add_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ ..._multiple_d_layernorm_wmma_cshuffle_v3.hpp | 22 ++++++ .../device_gemm_reduce_wmma_cshuffle_v3.hpp | 22 ++++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 20 +++++ .../impl/device_gemm_wmma_cshuffle_v3r1.hpp | 20 +++++ ...v_bwd_data_multiple_d_wmma_cshuffle_v3.hpp | 8 +- ..._multiple_d_wmma_cshuffle_tile_loop_v3.hpp | 23 ++++++ ...e_grouped_gemm_wmma_splitk_cshuffle_v3.hpp | 23 +++++- .../grid/gridwise_ab_transfer_wave_tiles.hpp | 4 - .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 79 ++++++++++++++++--- ...wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp | 3 +- ...wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp | 5 +- ...wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp | 5 +- ...mm_wmma_universal_f16_f16_f16_km_kn_mn.hpp | 2 +- ...mm_wmma_universal_f16_f16_f16_km_nk_mn.hpp | 4 +- ...mm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp | 4 +- .../test_gemm_universal_ut_cases_fp16.inc | 8 +- 23 files changed, 385 insertions(+), 50 deletions(-) diff --git a/example/01_gemm/gemm_wmma_fp16_v3.cpp b/example/01_gemm/gemm_wmma_fp16_v3.cpp index 5b10edd681..3b3b0fec16 100644 --- a/example/01_gemm/gemm_wmma_fp16_v3.cpp +++ b/example/01_gemm/gemm_wmma_fp16_v3.cpp @@ -19,22 +19,22 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; // clang-format off using DeviceGemmV2Instance = ck::tensor_operation::device::DeviceGemm_Wmma_CShuffleV3< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, - PassThrough, PassThrough, PassThrough, GemmDefault, + PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, + S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, + 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, - S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, - 1, 1, 8, 1, + 1, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1>; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp index 701c786c86..1c322fe4a7 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_global.hpp @@ -160,6 +160,7 @@ struct ThreadGroupTransferGlobal // check if src element is valid const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + oob_thread_scratch_.template SetAsType(vgpr_data_idx_seq, is_src_valid); // Vector length of elementwise operation constexpr auto get_elem_op_vec_len = []() { @@ -195,14 +196,12 @@ struct ThreadGroupTransferGlobal using dst_vector_type = vector_type_maker_t; using dst_vector_t = typename dst_vector_type::type; - using vector_t = typename vector_type_maker::type::type; - dst_vector_type op_r_v; // Load data from memory in src_vector first - src_vector_container src_vector = - src_vector_container{grid_buf.template Get( - src_coord_.GetOffset(), true)}; + auto index = is_src_valid || !DoTranspose ? src_coord_.GetOffset() : 0; + src_vector_container src_vector = src_vector_container{ + grid_buf.template Get(index, true)}; // apply the src elementwise op and convert to DstData under the hood if needed static_for<0, VectorSize / elem_op_vec_len, 1>{}([&](auto idx) { @@ -213,9 +212,8 @@ struct ThreadGroupTransferGlobal // store result in dvgpr_ (static array holding loaded data). // At this point data is already converted to DstData type and // the elementwise operation has been applied - dvgpr_.template SetAsType( - vgpr_data_idx_seq, - is_src_valid ? op_r_v.template AsType()[I0] : vector_t(0)); + src_dvgpr_.template SetAsType(vgpr_data_idx_seq, + op_r_v.template AsType()[I0]); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -248,6 +246,39 @@ struct ThreadGroupTransferGlobal container_reorder_given_new2old(src_access_lengths, src_dim_access_order); constexpr auto ordered_fwd_step = StepsPerIteration{}; + // OOB check + static_ford{}([&](auto ordered_src_access_idx) { + // calculate src data index and make sequence + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { ordered_idx(i) = ordered_src_access_idx[i]; }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order); + }(); + + // make sequence to access vgpr data. Add zero as last element of src_data_idx_seq + constexpr auto vgpr_data_idx_seq = generate_sequence_v2( + [&](auto i) { + if constexpr(i.value < src_data_idx.Size()) + { + return Number{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + auto op_r = src_dvgpr_.template GetAsType(vgpr_data_idx_seq); + const bool is_src_valid = + oob_thread_scratch_.template GetAsType(vgpr_data_idx_seq); + auto op_r_v = is_src_valid ? op_r : dst_vector_t(0); + dst_dvgpr_.template SetAsType(vgpr_data_idx_seq, op_r_v); + }); + // make forward steps // forward step for each iteration just add 1 const auto dst_forward_steps = generate_tuple( @@ -352,7 +383,7 @@ struct ThreadGroupTransferGlobal dst_buf.template Set( dst_coord_.GetOffset(), true, - dvgpr_.template GetAsType(vgpr_data_idx_seq)); + dst_dvgpr_.template GetAsType(vgpr_data_idx_seq)); // For each dimension move fwd, bwd or don't move static_for<0, nDim, 1>{}([&](auto i) { @@ -389,6 +420,14 @@ struct ThreadGroupTransferGlobal return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); } + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto access_lengths_as_tuple = + container_push_back(sequence_to_tuple_of_number(NumberOfIterations{}), Number<1>{}); + + return make_naive_tensor_descriptor_packed(access_lengths_as_tuple); + } + static constexpr auto thread_data_scratch_desc_ = decltype(GetThreadScratchDataDescriptor()){}; using ThreadScratchData = StaticTensorTupleOfVectorBuffer; - ThreadScratchData dvgpr_; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + using OOBThreadScratch = StaticTensorTupleOfVectorBuffer; + + ThreadScratchData src_dvgpr_; + ThreadScratchData dst_dvgpr_; + OOBThreadScratch oob_thread_scratch_; SrcCoord src_coord_; DstCoord dst_coord_; const ElementwiseOperation element_op_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp index 47ef2e339d..b59357ffe9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle_v3.hpp @@ -833,6 +833,26 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle_V3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + // check vector access static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp index 126d107725..ae247f4e31 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp @@ -606,6 +606,26 @@ struct DeviceBatchedGemmMultiD_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp index 227a8aedd9..593a908498 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_wmma_cshuffle_v3.hpp @@ -588,6 +588,28 @@ struct DeviceBatchedGemmReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp index 59a820861c..fb1ca3127e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_common.hpp @@ -455,6 +455,26 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_Common return false; } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp index e8e3b69cb5..85ca16b293 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_wmma_cshuffle_v3.hpp @@ -471,6 +471,28 @@ struct DeviceGemmBiasAddReduce_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp index f0216c3f71..81f505b594 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp @@ -701,6 +701,28 @@ struct DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemmWelford::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemmWelford::Argument gemm_arg{ std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp index 317c4073df..28c9f2bddc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_wmma_cshuffle_v3.hpp @@ -456,6 +456,28 @@ struct DeviceGemmReduce_Wmma_CShuffleV3 : public DeviceGemmReduce<0, ReduceOpera return false; } + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityAWaveTransfer(arg.MRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && + !GridwiseGemm::CheckValidityBWaveTransfer(arg.NRaw_, arg.KRaw_)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + typename GridwiseGemm::Argument gemm_arg{std::array{arg.p_a_grid_}, std::array{arg.p_b_grid_}, std::array{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index e96ec58cba..c09befa717 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -421,6 +421,26 @@ struct DeviceGemm_Wmma_CShuffleV3_Common } } + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityAWaveTransfer(arg.M, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix A" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + + if(ck::is_gfx12_supported() && !GridwiseGemm::CheckValidityBWaveTransfer(arg.N, arg.K)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Wave Transfer not applicable for matrix B" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp index e09c69d052..377f792979 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3r1.hpp @@ -393,6 +393,26 @@ struct DeviceGemm_Wmma_CShuffleV3R1 : public DeviceGemmV2R1(&arg)); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp index bbf62d5fbe..dfdfd53725 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle_v3.hpp @@ -450,8 +450,10 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 BlkGemmPipelineVer, AComputeType, BComputeType, - false, - false>; + false, // PermuteA + false, // PermuteB + false, // IsBPreShuffled + true>; // ForceThreadTileTransfer #define GridwiseGemmCTransposeTemplateParameters \ ALayout, BLayout, DsLayout, ELayout, Tuple, Tuple, AccDataType, \ @@ -467,7 +469,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3 ABlockLdsExtraM, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, \ CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CShuffleBlockTransferScalarPerVector, BlkGemmPipeSched, BlkGemmPipelineVer, BComputeType, \ - AComputeType, false, false + AComputeType, false, false, false, true using GridwiseGemmCTranspose = std::conditional_t placeholder_p_ds_grid{}; std::array stride_Ds; std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp index 39024d39e4..99a18e07fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -704,7 +704,28 @@ struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK::selected_wmma .wave_size; + __host__ __device__ static constexpr bool AWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && + ABlockTransferSrcScalarPerVector == 8 && ABlockTransferDstScalarPerVector_AK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && + !IsBPreShuffled; + } + + __host__ __device__ static constexpr bool BWaveTransferApplicable() + { + return !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && + BBlockTransferSrcScalarPerVector == 8 && BBlockTransferDstScalarPerVector_BK1 == 8 && + BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + } + // Limitations of the current implementation: // - no multiAB - // - GemmSpecialization Default with transpose #ifdef __gfx12__ - static constexpr bool IsAWaveTransferApplicable = - !ForceThreadTileTransfer && NumATensor == 1 && APackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && AK1Value == 8 && !IsBPreShuffled; + static constexpr bool IsAWaveTransferApplicable = AWaveTransferApplicable(); - static constexpr bool IsBWaveTransferApplicable = - !ForceThreadTileTransfer && NumBTensor == 1 && BPackedSize == 1 && - ((GemmSpec == tensor_operation::device::GemmSpecialization::Default && - !is_same_v) || - is_same_v) && - BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 && BK1Value == 8; + static constexpr bool IsBWaveTransferApplicable = BWaveTransferApplicable(); static constexpr bool IsWaveTileInterleavedFitting = (NPerBlock / NPerWmma / NRepeat) * (KPerBlock / KPack) >= (BlockSize / WaveSize); @@ -982,6 +986,55 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return de_grid_desc_mblock_mperblock_nblock_nperblock; } + // Conditions for Wave Transfer with transpose: + // - 16 bit type: K % 8 == 0 (4 subtiles of 8x8) + // - 8 bit type: K % 8 == 0 and M % 16 == 0 (2 subtiles of 8x16) + __host__ static constexpr bool CheckValidityAWaveTransfer(const index_t& M, const index_t& K) + { + if constexpr(AWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % ABlockTransferDstScalarPerVector_AK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + pass &= !(sizeof(ADataType_) == 1 && + !(M % (2 * ABlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + + __host__ static constexpr bool CheckValidityBWaveTransfer(const index_t& N, const index_t& K) + { + if constexpr(BWaveTransferApplicable() && + !(is_same::value)) + { + if(!(K % BBlockTransferDstScalarPerVector_BK1 == 0)) + { + return false; + } + bool pass = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + pass &= !(sizeof(BDataType_) == 1 && + !(N % (2 * BBlockTransferSrcScalarPerVector) == 0)); + }); + return pass; + } + else + { + return true; + } + } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} template __host__ static constexpr bool CheckValidity(const Argument& karg, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index d79fe9bfa3..d7b654a345 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -47,7 +47,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index e284cbbb83..7d7966c47f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -40,7 +40,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 6195d40f87..2f63199480 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -52,7 +52,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::t DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1> // clang-format on >; } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index e51bec3dfb..b50e37cf0a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -44,7 +44,7 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 66ba1e3830..4651068d86 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -40,9 +40,9 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tupl //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 8eccccf354..4dcbaccaa4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -41,7 +41,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1>, @@ -49,7 +49,7 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tupl DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 32, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1>, - DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, + DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, diff --git a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc index 25d95cda3d..01d7d5a5fd 100644 --- a/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc +++ b/test/gemm_universal/test_gemm_universal_ut_cases_fp16.inc @@ -125,7 +125,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_NK, MidLargeM) TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -139,7 +139,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -153,7 +153,7 @@ TYPED_TEST(TestGemmUniversal_FP16_MK_NK, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; @@ -169,7 +169,7 @@ TYPED_TEST(TestGemmUniversal_FP16_KM_KN, PaddK) TYPED_TEST(TestGemmUniversal_FP16_KM_NK, PaddK) { - std::vector Ms{127}; + std::vector Ms{127, 128}; constexpr int N = 512; constexpr int K = 437; From a213ce676bb6b72e177f73befa4d56b0ce60fbec Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 26 Jan 2026 13:44:36 -0800 Subject: [PATCH 11/28] Add python analysis scripts for Clang's time trace (#3644) This PR introduces a Python toolkit for analyzing Clang's `-ftime-trace` build performance data. This is the foundation for our systematic effort to reduce CK and CK-Tile build times (#3575). The toolkit provides fast parsing of trace JSON files into pandas DataFrames using orjson, with specialized functions for analyzing template instantiation costs and compilation phase breakdowns. It includes a core library (`trace_analysis/`), example scripts for quick analysis, a comprehensive README with usage documentation, and an interactive Jupyter notebook demonstration. Key features include memory-efficient DataFrame schemas with optimized dtypes, recursive hierarchical phase analysis, automatic metadata extraction (source file, compilation timing), and template instantiation filtering. The design supports both standalone scripts and interactive Jupyter notebook workflows. This single-file analysis capability lays the groundwork for future multi-file analysis across thousands of compilation units, enabling data-driven optimization and build time regression detection. --- script/analyze_build/README.md | 263 +++++++++++++ .../notebooks/file_analysis_example.ipynb | 247 ++++++++++++ script/analyze_build/requirements.txt | 18 + .../analyze_build/trace_analysis/__init__.py | 34 ++ .../trace_analysis/parse_file.py | 356 ++++++++++++++++++ .../trace_analysis/phase_breakdown.py | 354 +++++++++++++++++ .../trace_analysis/template_analysis.py | 80 ++++ .../trace_analysis/template_parser.py | 301 +++++++++++++++ 8 files changed, 1653 insertions(+) create mode 100644 script/analyze_build/README.md create mode 100644 script/analyze_build/notebooks/file_analysis_example.ipynb create mode 100644 script/analyze_build/requirements.txt create mode 100644 script/analyze_build/trace_analysis/__init__.py create mode 100644 script/analyze_build/trace_analysis/parse_file.py create mode 100644 script/analyze_build/trace_analysis/phase_breakdown.py create mode 100644 script/analyze_build/trace_analysis/template_analysis.py create mode 100644 script/analyze_build/trace_analysis/template_parser.py diff --git a/script/analyze_build/README.md b/script/analyze_build/README.md new file mode 100644 index 0000000000..7a88b98e77 --- /dev/null +++ b/script/analyze_build/README.md @@ -0,0 +1,263 @@ +# Build Trace Analysis + +Simple to use, fast python tools for analyzing Clang `-ftime-trace` build performance data. + +## Overview + +We're kicking off a systematic effort to dramatically reduce CK and CK-Tile build times, [#3575](https://github.com/ROCm/composable_kernel/issues/3575). A key part of this work is improving our C++ metaprogramming to reduce the burden on the compiler. + +In order to prioritize work and measure our progress, we need data on template instantiation. For single files, Clang's `-ftime-trace` build performance data is easy to analyze with the Perfetto UI. The problem we are solving here is how to analyze instantiation data across thousands of compilation units. + +The python code in this directory provides helper functions to quickly load JSON files into pandas DataFrames that can be used for analysis in Jupyter notebooks. + +## Directory Structure + +``` +script/analyze_build/ +├── trace_analysis/ # Core library +│ ├── __init__.py # Main exports +│ ├── parse_file.py # Fast parsing of JSON trace files +│ ├── template_analysis.py # Template instantiation analysis +│ ├── template_parser.py # Template name parsing utilities +│ └── phase_breakdown.py # Compilation phase breakdown +├── notebooks/ # Jupyter notebooks for analysis +│ └── file_analysis_example.ipynb # Template analysis example +├── requirements.txt # Python dependencies +└── README.md # This file +``` + +## Python Requirements + +See `requirements.txt` for the complete list of dependencies: +* **pandas** - DataFrame manipulation and analysis +* **orjson** - Fast JSON parsing for trace files +* **plotly** - Interactive visualizations (sunburst, treemap) +* **nbformat** - Jupyter notebook format support +* **ipykernel** - Kernel for running notebooks in VSCode/Jupyter +* **kaleido** - Static image export from Plotly charts +* **jupyter** - Full Jupyter environment + +## Quick Start + +### Setup + +1. Create a virtual environment (recommended): +```bash +cd script/analyze_build +python3 -m venv .venv +source .venv/bin/activate # On Windows: .venv\Scripts\activate +``` + +2. Install dependencies: +```bash +pip install -r requirements.txt +``` + +3. Install VSCode extensions if you want to run notebooks in VSCode: + * Jupyter + * Data Wrangler (interact with Pandas DataFrames) + +### Analyzing a Single File + +Use the `parse_file` function to load a `-ftime-trace` JSON file into a Pandas DataFrame: + +```python +from trace_analysis import parse_file + +# Parse the trace file +df = parse_file('path/to/trace.json') + +# View basic info +print(f"Total events: {len(df)}") +print(df.columns) + +# Analyze duration statistics +print(df['dur'].describe()) +``` + +### Extracting Compilation Metadata + +Get high-level metadata about the compilation: + +```python +from trace_analysis import get_metadata + +# Extract metadata from trace file +metadata = get_metadata('trace.json') + +print(f"Source file: {metadata['source_file']}") +print(f"Compilation time: {metadata['total_wall_time_s']:.2f}s") +print(f"Started: {metadata['wall_start_datetime']}") +print(f"Ended: {metadata['wall_end_datetime']}") +``` + +The metadata includes: +- `source_file`: Main .cpp/.c file being compiled +- `time_granularity`: Time unit used ("microseconds") +- `beginning_of_time`: Epoch timestamp in microseconds +- `wall_start_time`: Wall clock start (microseconds since epoch) +- `wall_end_time`: Wall clock end (microseconds since epoch) +- `wall_start_datetime`: Human-readable start time +- `wall_end_datetime`: Human-readable end time +- `total_wall_time_us`: Total compilation time in microseconds +- `total_wall_time_s`: Total compilation time in seconds + +### Template Instantiation Analysis + +The module includes specialized functions for analyzing C++ template instantiation costs: + +```python +from trace_analysis import ( + parse_file, + get_template_instantiation_events, + get_phase_breakdown, +) + +df = parse_file('trace.json') + +# Get all template instantiation events with parsed template information +template_events = get_template_instantiation_events(df) + +# The returned DataFrame includes parsed columns: +# - namespace: Top-level namespace (e.g., 'std', 'ck') +# - template_name: Template name without parameters +# - full_qualified_name: Full namespace::template_name +# - param_count: Number of template parameters +# - is_ck_type: Boolean indicating CK library types +# - is_nested: Boolean indicating nested templates + +# Find slowest template instantiations +top_templates = template_events.nlargest(20, 'dur') +print(top_templates[['template_name', 'namespace', 'param_count', 'dur']]) + +# Analyze by namespace +namespace_summary = template_events.groupby('namespace').agg({ + 'dur': ['count', 'sum', 'mean'] +}) +print(namespace_summary) +``` + +### Compilation Phase Breakdown + +Analyze how compilation time is distributed across different phases: + +```python +from trace_analysis import get_phase_breakdown, PhaseBreakdown + +df = parse_file('trace.json') + +# Get hierarchical phase breakdown +breakdown = get_phase_breakdown(df) + +# Display in Jupyter (automatic rich HTML display) +display(breakdown) + +# Print text representation +print(breakdown) + +# Access the underlying DataFrame +print(breakdown.df) + +# Convert to plotly format for visualization +import plotly.express as px +data = breakdown.to_plotly() +fig = px.sunburst(**data) +fig.show() +``` + +The `PhaseBreakdown` class provides: +- Hierarchical breakdown of compilation phases +- Automatic calculation of "Other" residual time at each level +- Validation that children don't exceed parent durations +- Multiple output formats (text, DataFrame, Plotly) + +## DataFrame Schema + +The parsed DataFrame contains the following columns from the `-ftime-trace` format: + +- `name`: Event name (function, template instantiation, etc.) +- `ph`: Phase character ('X' for complete, 'B' for begin, 'E' for end, 'i' for instant) +- `ts`: Timestamp in microseconds +- `dur`: Duration in microseconds (for complete events) +- `pid`: Process ID +- `tid`: Thread ID +- `arg_*`: Flattened arguments from the event's `args` field + +### Template Event Columns + +When using `get_template_instantiation_events()`, additional parsed columns are included: + +- `namespace`: Top-level namespace extracted from the template name +- `template_name`: Template name without namespace or parameters +- `full_qualified_name`: Complete namespace::template_name +- `param_count`: Number of template parameters +- `is_ck_type`: Boolean flag for CK library types (namespace starts with 'ck') +- `is_nested`: Boolean flag indicating nested template instantiations + +## Use in Jupyter Notebooks + +The module is designed to work seamlessly in Jupyter notebooks. See `notebooks/file_analysis_example.ipynb` for a complete example workflow that demonstrates: + +- Loading and parsing trace files +- Extracting compilation metadata +- Analyzing phase breakdown with visualizations +- Template instantiation analysis with parsed columns +- Filtering and grouping by namespace +- Identifying CK-specific template costs + +To use in a notebook: + +```python +import sys +from pathlib import Path + +# Add trace_analysis to path +sys.path.insert(0, str(Path.cwd().parent)) + +from trace_analysis import ( + parse_file, + get_metadata, + get_template_instantiation_events, + get_phase_breakdown, +) + +# Load and analyze +df = parse_file('path/to/trace.json') +breakdown = get_phase_breakdown(df) +templates = get_template_instantiation_events(df) + +# Visualize +import plotly.express as px +fig = px.sunburst(**breakdown.to_plotly()) +fig.show() +``` + +## API Reference + +### Core Functions + +- `parse_file(filepath)`: Parse a `-ftime-trace` JSON file into a pandas DataFrame +- `get_metadata(filepath_or_df)`: Extract compilation metadata from trace file or DataFrame + +### Template Analysis + +- `get_template_instantiation_events(df)`: Filter to template instantiation events with parsed template information + +### Phase Breakdown + +- `get_phase_breakdown(df)`: Generate hierarchical compilation phase breakdown +- `PhaseBreakdown`: Class representing phase breakdown with multiple output formats + +## Contributing + +This is an experimental project for analyzing and improving C++ metaprogramming build times. Contributions are welcome! When adding new analysis functions: + +1. Add the function to the appropriate module in `trace_analysis/` +2. Export it in `__init__.py` +3. Update this README with usage examples +4. Consider adding a notebook example if the feature is substantial + +## License + +Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +SPDX-License-Identifier: MIT diff --git a/script/analyze_build/notebooks/file_analysis_example.ipynb b/script/analyze_build/notebooks/file_analysis_example.ipynb new file mode 100644 index 0000000000..e8d1ee3bcd --- /dev/null +++ b/script/analyze_build/notebooks/file_analysis_example.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Template Instantiation Analysis Example\n", + "\n", + "This notebook demonstrates how to use the template analysis functions to understand C++ template instantiation costs in Clang's `-ftime-trace` output.\n", + "\n", + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "# Add parent directory to path\n", + "sys.path.insert(0, str(Path.cwd().parent))\n", + "\n", + "from trace_analysis import (\n", + " parse_file,\n", + " get_template_instantiation_events,\n", + " get_phase_breakdown,\n", + " get_metadata,\n", + ")\n", + "\n", + "import pandas as pd\n", + "from datetime import datetime\n", + "import plotly.express as px\n", + "\n", + "\n", + "# Display settings\n", + "pd.set_option(\"display.max_rows\", 100)\n", + "pd.set_option(\"display.max_columns\", None)\n", + "pd.set_option(\"display.width\", None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Trace File" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load your trace file\n", + "trace_file = Path(\n", + " \"../../../build-trace/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeFiles/device_conv2d_fwd_instance.dir/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp.json\"\n", + ")\n", + "df = parse_file(trace_file)\n", + "\n", + "print(f\"Total events: {len(df):,}\")\n", + "starting_timestamp = datetime.fromtimestamp(df.attrs[\"beginningOfTime\"] / 1e6)\n", + "print(f\"Starting timestamp: {starting_timestamp.strftime('%Y-%m-%d:%H:%M:%S')}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "get_metadata(df)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compilation Overview" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get phase breakdown and display it\n", + "breakdown = get_phase_breakdown(df)\n", + "print(breakdown)\n", + "display(breakdown)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Extract data for plotly charts (sunburst, tree-map, or icicle)\n", + "plotly_data = breakdown.to_plotly()\n", + "fig = px.sunburst(**plotly_data)\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Template Instantiation Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get all template instantiation events (now with parsed columns!)\n", + "template_events = get_template_instantiation_events(df)\n", + "\n", + "print(f\"Total template instantiation events: {len(template_events):,}\")\n", + "print(f\"Total template time: {template_events['dur'].sum() / 1000:.1f} ms\")\n", + "display(template_events)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Examine Parsed Columns\n", + "\n", + "The `get_template_instantiation_events()` function automatically parses the `arg_detail` column into structured fields:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Show the new parsed columns\n", + "print(\"Parsed columns available:\")\n", + "print(\"- namespace: Top-level namespace (e.g., 'std', 'ck')\")\n", + "print(\"- template_name: Template name without parameters\")\n", + "print(\"- full_qualified_name: Full namespace::template_name\")\n", + "print(\"- param_count: Number of template parameters\")\n", + "print(\"- is_ck_type: Boolean indicating CK library types\")\n", + "print(\"- is_nested: Boolean indicating nested templates\")\n", + "print()\n", + "\n", + "# Display sample of parsed data\n", + "template_events[\n", + " [\"namespace\", \"template_name\", \"param_count\", \"is_ck_type\", \"is_nested\", \"dur\"]\n", + "].head(20)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analysis by Namespace" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Group by namespace to see where time is spent\n", + "namespace_summary = (\n", + " template_events.groupby(\"namespace\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"], \"param_count\": \"mean\"})\n", + " .round(2)\n", + ")\n", + "\n", + "namespace_summary.columns = [\"count\", \"total_dur\", \"avg_dur\", \"avg_params\"]\n", + "namespace_summary[\"total_ms\"] = namespace_summary[\"total_dur\"] / 1000\n", + "namespace_summary = namespace_summary.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTemplate Instantiation Time by Namespace:\")\n", + "display(namespace_summary)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CK Library Templates Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Filter to CK types only\n", + "ck_templates = template_events[template_events[\"is_ck_type\"]].copy()\n", + "\n", + "print(f\"CK template instantiations: {len(ck_templates):,}\")\n", + "print(f\"CK template time: {ck_templates['dur'].sum() / 1000:.1f} ms\")\n", + "print(\n", + " f\"Percentage of total template time: {100 * ck_templates['dur'].sum() / template_events['dur'].sum():.1f}%\"\n", + ")\n", + "print()\n", + "\n", + "# Top CK templates by time\n", + "ck_by_name = (\n", + " ck_templates.groupby(\"template_name\")\n", + " .agg({\"dur\": [\"count\", \"sum\", \"mean\"]})\n", + " .round(2)\n", + ")\n", + "ck_by_name.columns = [\"count\", \"total_dur\", \"avg_dur\"]\n", + "ck_by_name[\"total_ms\"] = ck_by_name[\"total_dur\"] / 1000\n", + "ck_by_name = ck_by_name.sort_values(\"total_dur\", ascending=False)\n", + "\n", + "print(\"\\nTop CK Templates by Total Time:\")\n", + "display(ck_by_name.head(20))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/script/analyze_build/requirements.txt b/script/analyze_build/requirements.txt new file mode 100644 index 0000000000..fd99fdba09 --- /dev/null +++ b/script/analyze_build/requirements.txt @@ -0,0 +1,18 @@ +# Build Trace Analysis - Python Dependencies + +# Core data processing +pandas>=2.0.0 +orjson>=3.9.0 + +# Jupyter notebook support +nbformat>=4.2.0 +ipykernel>=6.0.0 + +# Interactive visualizations +plotly>=5.0.0 + +# Static image export from Plotly +kaleido>=0.2.0 + +# Full Jupyter environment (if not using VSCode) +jupyter>=1.0.0 diff --git a/script/analyze_build/trace_analysis/__init__.py b/script/analyze_build/trace_analysis/__init__.py new file mode 100644 index 0000000000..70db321083 --- /dev/null +++ b/script/analyze_build/trace_analysis/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Build Trace Analysis - Core library for analyzing Clang -ftime-trace data. + +This package provides tools to parse and analyze Clang's -ftime-trace JSON output +for build performance analysis. +""" + +from .parse_file import ( + parse_file, + get_metadata, +) + +from .template_analysis import ( + get_template_instantiation_events, +) + +from .phase_breakdown import ( + get_phase_breakdown, + PhaseBreakdown, +) + +__all__ = [ + # Core parsing and filtering + "parse_file", + "get_metadata", + # Template analysis + "get_template_instantiation_events", + # Phase breakdown + "get_phase_breakdown", + "PhaseBreakdown", +] diff --git a/script/analyze_build/trace_analysis/parse_file.py b/script/analyze_build/trace_analysis/parse_file.py new file mode 100644 index 0000000000..24d71e4eb8 --- /dev/null +++ b/script/analyze_build/trace_analysis/parse_file.py @@ -0,0 +1,356 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Parse a single Clang -ftime-trace JSON file into a Pandas DataFrame. + +This module provides fast parsing of Clang's -ftime-trace output using orjson +for performance. The JSON file is typically a single-line array of trace events. +""" + +import orjson +import pandas as pd +from pathlib import Path +from typing import Union, Optional +from datetime import datetime +from dataclasses import dataclass + + +# Expected schema for trace event DataFrames with optimized dtypes +# This enforces strict column validation and memory-efficient types +# The memory usage is dominated by arg detail, but we optimize each series. +TRACE_EVENT_DTYPES = { + "pid": "int32", # Process ID (max observed: ~2.3M, fits in int32) + "tid": "int32", # Thread ID (max observed: ~2.3M, fits in int32) + "ts": "int64", # Timestamp in microseconds (requires int64 for epoch times) + "cat": "category", # Category (low cardinality, use categorical) + "ph": "category", # Phase type (very low cardinality: X, B, E, i, etc.) + "id": "int64", # Event ID + "name": "category", # Event name (medium cardinality, use categorical) + "dur": "int64", # Duration in microseconds (max 10 days = 864B μs, requires int64) + "arg_detail": "string", # Detail string (high cardinality, keep as string) + "arg_count": "int64", # Argument count + "arg_avg ms": "int64", # Average milliseconds + "arg_name": "category", # Argument name (medium cardinality, use categorical) +} + + +@dataclass +class FileMetadata: + """ + Processed metadata with computed fields for compilation analysis. + + This extends the raw metadata with derived values like formatted timestamps + and converted time units for convenience. + + Attributes: + source_file: Main .cpp/.c file being compiled + time_granularity: Time unit used in trace (always "microseconds" for Clang) + beginning_of_time: Epoch timestamp in microseconds from JSON root + execute_compiler_ts: Timestamp of ExecuteCompiler event (microseconds) + execute_compiler_dur: Duration of ExecuteCompiler event (microseconds) + total_wall_time_us: Total compilation time in microseconds (same as execute_compiler_dur) + total_wall_time_s: Total compilation time in seconds (computed from microseconds) + wall_start_time: Wall clock start time in microseconds since epoch (computed) + wall_end_time: Wall clock end time in microseconds since epoch (computed) + wall_start_datetime: Human-readable start time string (formatted) + wall_end_datetime: Human-readable end time string (formatted) + """ + + source_file: Optional[str] = None + time_granularity: str = "microseconds" + beginning_of_time: Optional[int] = None + execute_compiler_ts: Optional[int] = None + execute_compiler_dur: Optional[int] = None + total_wall_time_us: Optional[int] = None + total_wall_time_s: Optional[float] = None + wall_start_time: Optional[int] = None + wall_end_time: Optional[int] = None + wall_start_datetime: Optional[str] = None + wall_end_datetime: Optional[str] = None + + def __repr__(self): + # auto-generate pretty lines + fields = "\n".join( + f" {name} = {value!r}" for name, value in self.__dict__.items() + ) + return f"{self.__class__.__name__}(\n{fields}\n)" + + +def parse_file(filepath: Union[str, Path]) -> pd.DataFrame: + """ + Parse a Clang -ftime-trace JSON file into a Pandas DataFrame. + + The -ftime-trace format is a JSON array of trace events. Each event contains + fields like name, phase (ph), timestamp (ts), duration (dur), process/thread IDs, + and optional arguments (args). + + The beginningOfTime value from the JSON structure is automatically extracted + and stored in df.attrs['beginningOfTime']. Use get_metadata(df) to get + processed metadata with event-derived fields and computed values. + + Args: + filepath: Path to the -ftime-trace JSON file + + Returns: + DataFrame with columns for each event field. Nested 'args' are flattened + with an 'arg_' prefix. The beginningOfTime value is stored in + df.attrs['beginningOfTime']. + + Raises: + FileNotFoundError: If the file doesn't exist + ValueError: If the JSON is invalid or empty + + Examples: + >>> df = parse_file('build/trace.json') + >>> df[['name', 'dur']].head() + >>> + >>> # Access processed metadata + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> + >>> # Access beginningOfTime directly if needed + >>> beginning = df.attrs.get('beginningOfTime') + >>> print(f"Beginning of time: {beginning}") + """ + filepath = Path(filepath) + + if not filepath.exists(): + raise FileNotFoundError(f"Trace file not found: {filepath}") + + # Read and parse JSON using orjson for speed + with open(filepath, "rb") as f: + data = orjson.loads(f.read()) + + if not data: + raise ValueError(f"Empty trace data in file: {filepath}") + + # Handle both formats: direct array or {"traceEvents": [...]} + if isinstance(data, dict): + if "traceEvents" in data: + events = data["traceEvents"] + else: + raise ValueError( + f"Expected 'traceEvents' key in JSON object, got keys: {list(data.keys())}" + ) + elif isinstance(data, list): + events = data + else: + raise ValueError(f"Expected JSON array or object, got {type(data).__name__}") + + # Convert to DataFrame + df = pd.DataFrame(events) + + if df.empty: + raise ValueError(f"No trace events found in file: {filepath}") + + # Flatten 'args' column if it exists + if "args" in df.columns: + df = _flatten_args(df) + + # Validate schema: check for missing columns + expected_columns = set(TRACE_EVENT_DTYPES.keys()) + actual_columns = set(df.columns) + + missing_columns = expected_columns - actual_columns + if missing_columns: + raise ValueError( + f"Missing expected columns in trace data: {sorted(missing_columns)}" + ) + + # Validate schema: check for unexpected columns + unexpected_columns = actual_columns - expected_columns + if unexpected_columns: + raise ValueError( + f"Unexpected columns found in trace data: {sorted(unexpected_columns)}" + ) + + # Apply optimized dtypes with strict type enforcement + for col, dtype in TRACE_EVENT_DTYPES.items(): + if dtype in ("int64", "int32"): + # Fill missing values with 0 for integer columns, then convert to specified int type + df[col] = df[col].fillna(0).astype(dtype) + elif dtype == "category": + # Convert to categorical for memory efficiency with repeated values + df[col] = df[col].astype("category") + elif dtype == "string": + # Convert to pandas string dtype for memory efficiency + df[col] = df[col].astype("string") + else: + raise ValueError(f"Unsupported dtype '{dtype}' for column '{col}'") + + # Extract and store beginningOfTime in DataFrame attributes + df.attrs["beginningOfTime"] = ( + data.get("beginningOfTime") if isinstance(data, dict) else None + ) + + return df + + +def _flatten_args(df: pd.DataFrame) -> pd.DataFrame: + """ + Flatten the 'args' column into separate columns with 'arg_' prefix. + + The 'args' field in trace events contains additional metadata as a dictionary. + This function extracts those key-value pairs into separate columns. + + Args: + df: DataFrame with an 'args' column containing dictionaries + + Returns: + DataFrame with flattened args columns and original 'args' column removed + """ + # Extract args into separate DataFrame + args_data = [] + for idx, row in df.iterrows(): + args = row.get("args", {}) + if isinstance(args, dict): + args_data.append(args) + else: + args_data.append({}) + + if args_data: + args_df = pd.DataFrame(args_data) + # Prefix all args columns with 'arg_' + args_df.columns = [f"arg_{col}" for col in args_df.columns] + + # Drop original args column and concatenate flattened args + df = df.drop(columns=["args"]) + df = pd.concat([df, args_df], axis=1) + + return df + + +def _normalize_source_path(file_path: str) -> str: + """ + Normalize a source file path to be relative to composable_kernel if present. + + If 'composable_kernel' appears in the path, returns the path starting from + 'composable_kernel/'. Otherwise, returns the original path unchanged. + + Args: + file_path: Full filesystem path to a source file + + Returns: + Normalized path starting from composable_kernel, or original path if + composable_kernel is not found + + Examples: + >>> _normalize_source_path('/home/user/composable_kernel/include/ck/tensor.hpp') + 'composable_kernel/include/ck/tensor.hpp' + >>> _normalize_source_path('/usr/include/vector') + '/usr/include/vector' + """ + path = Path(file_path) + parts = path.parts + + # Find the last occurrence of 'composable_kernel' in the path + for i in range(len(parts) - 1, -1, -1): + if parts[i] == "composable_kernel": + # Return path from composable_kernel onwards + return str(Path(*parts[i:])) + + # If composable_kernel not found, return original path + return file_path + + +def get_metadata(df: pd.DataFrame) -> FileMetadata: + """ + Extract and process compilation metadata from a DataFrame. + + This function processes events from the DataFrame to extract compilation + information, then computes derived fields like formatted timestamps and + converted time units. + + Args: + df: DataFrame returned by parse_file() with beginningOfTime in its .attrs + + Returns: + FileMetadata instance with both raw and computed fields: + - source_file: Main .cpp/.c file being compiled (from events) + - time_granularity: Time unit used in trace ("microseconds") + - beginning_of_time: Epoch timestamp in microseconds from JSON root + - execute_compiler_ts: Timestamp of ExecuteCompiler event (from events) + - execute_compiler_dur: Duration of ExecuteCompiler event (from events) + - total_wall_time_us: Total compilation time in microseconds + - total_wall_time_s: Total compilation time in seconds (computed) + - wall_start_time: Wall clock start time (computed) + - wall_end_time: Wall clock end time (computed) + - wall_start_datetime: Human-readable start time (formatted) + - wall_end_datetime: Human-readable end time (formatted) + + Examples: + >>> df = parse_file('trace.json') + >>> metadata = get_metadata(df) + >>> print(f"Compiled: {metadata.source_file}") + >>> print(f"Duration: {metadata.total_wall_time_s:.2f}s") + >>> print(f"Started: {metadata.wall_start_datetime}") + """ + # Extract beginningOfTime from DataFrame attributes + beginning_of_time = None + if hasattr(df, "attrs"): + beginning_of_time = df.attrs.get("beginningOfTime") + + # Initialize metadata with beginningOfTime from JSON structure + metadata = FileMetadata(beginning_of_time=beginning_of_time) + + # Process events to extract ExecuteCompiler timing information + if "name" in df.columns: + execute_compiler = df[df["name"] == "ExecuteCompiler"] + if not execute_compiler.empty: + # Get the first ExecuteCompiler event + event = execute_compiler.iloc[0] + if "ts" in event: + metadata.execute_compiler_ts = event["ts"] + if "dur" in event: + metadata.execute_compiler_dur = event["dur"] + + # Process events to find the main source file being compiled + if "name" in df.columns and "arg_detail" in df.columns: + # Look for ParseDeclarationOrFunctionDefinition events with .cpp or .c files + source_extensions = (".cpp", ".cc", ".cxx", ".c") + parse_events = df[df["name"] == "ParseDeclarationOrFunctionDefinition"] + + for _, event in parse_events.iterrows(): + detail = event.get("arg_detail", "") + if detail: + # Extract file path (may include line:column info) + file_path = str(detail).split(":")[0] + + # Check if it's a source file (not a header) + if any(file_path.endswith(ext) for ext in source_extensions): + metadata.source_file = _normalize_source_path(file_path) + break + + # Compute derived fields + if metadata.execute_compiler_dur is not None: + metadata.total_wall_time_us = metadata.execute_compiler_dur + metadata.total_wall_time_s = metadata.execute_compiler_dur / 1_000_000.0 + + # Calculate wall clock times if we have the necessary data + if ( + metadata.beginning_of_time is not None + and metadata.execute_compiler_ts is not None + and metadata.execute_compiler_dur is not None + ): + metadata.wall_start_time = ( + metadata.beginning_of_time + metadata.execute_compiler_ts + ) + metadata.wall_end_time = ( + metadata.wall_start_time + metadata.execute_compiler_dur + ) + + # Convert to human-readable datetime strings + try: + start_dt = datetime.fromtimestamp(metadata.wall_start_time / 1_000_000.0) + end_dt = datetime.fromtimestamp(metadata.wall_end_time / 1_000_000.0) + metadata.wall_start_datetime = start_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[ + :-3 + ] + metadata.wall_end_datetime = end_dt.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + except (OSError, ValueError): + # Handle invalid timestamps gracefully + pass + + return metadata diff --git a/script/analyze_build/trace_analysis/phase_breakdown.py b/script/analyze_build/trace_analysis/phase_breakdown.py new file mode 100644 index 0000000000..773ba06622 --- /dev/null +++ b/script/analyze_build/trace_analysis/phase_breakdown.py @@ -0,0 +1,354 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Phase breakdown analysis for Clang -ftime-trace data. + +This module provides hierarchical breakdown of compilation phases using +the pre-aggregated "Total" events from Clang's -ftime-trace output. + +The data is returned as a PhaseBreakdown object with rich display and +analysis capabilities optimized for Jupyter notebooks. +""" + +import pandas as pd +from collections import namedtuple +from typing import Optional + + +# Lightweight namedtuple for iteration +Phase = namedtuple("Phase", ["name", "depth", "duration", "duration_ms", "percentage"]) + + +class PhaseBreakdown: + """ + Wrapper for compilation phase breakdown with notebook-friendly API. + + Provides hierarchical view of compilation phases from Clang -ftime-trace, + with rich display, filtering, and visualization capabilities. + + Examples: + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter + >>> breakdown + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend + >>> + >>> # Get metrics + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Top N analysis + >>> breakdown.top(10) + >>> breakdown.frontend.top(5) + >>> + >>> # Visualization + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Iteration + >>> for phase in breakdown: + >>> print(f"{phase.name}: {phase.duration_ms:.1f}ms") + """ + + def __init__(self, df: pd.DataFrame): + """ + Initialize from phase breakdown DataFrame. + + Args: + df: DataFrame with columns name, parent, depth, duration + """ + if df.empty: + self._df = pd.DataFrame(columns=["name", "parent", "depth", "duration"]) + self._total_time = 0 + else: + self._df = df + self._total_time = self._get_total_time() + + def __repr__(self) -> str: + """Simple text representation for console.""" + if self._df.empty: + return "PhaseBreakdown(empty)" + n_phases = len(self._df) + return f"PhaseBreakdown({n_phases} phases, {self._total_time:.1f}ms total)" + + def _repr_html_(self) -> str: + """Rich HTML representation for Jupyter notebooks.""" + if self._df.empty: + return "
PhaseBreakdown(empty)
" + return self.to_dataframe()._repr_html_() + + @property + def df(self) -> pd.DataFrame: + """ + Access underlying DataFrame. + + Returns: + DataFrame with columns name, parent, depth, duration + """ + return self._df + + def to_dataframe(self, show_percentages: bool = True) -> pd.DataFrame: + """ + Format as DataFrame for display. + + Creates a nicely formatted DataFrame suitable for Jupyter notebook display. + + Args: + show_percentages: Include percentage of total time + + Returns: + DataFrame with formatted columns + """ + return self._format_dataframe(show_percentages) + + def to_plotly(self) -> dict: + """ + Convert to plotly hierarchical visualization format. + + Returns a dictionary with data_frame, values, and path that can be directly + used with plotly.express sunburst, treemap, or icicle charts. + + Returns: + Dictionary with keys: data_frame, values, path, branchvalues + + Example: + >>> data = breakdown.to_plotly() + >>> import plotly.express as px + >>> + >>> # Create sunburst chart + >>> fig = px.sunburst(**data) + >>> fig.show() + >>> + >>> # Create treemap chart + >>> fig = px.treemap(**data) + >>> fig.show() + >>> + >>> # Create icicle chart + >>> fig = px.icicle(**data) + >>> fig.show() + """ + return self._build_plotly_data() + + # Internal helper methods + + def _get_total_time(self) -> int: + """Get total time from root ExecuteCompiler event.""" + root = self._df[self._df["depth"] == 0] + if root.empty: + return 0 + return int(root.iloc[0]["duration"]) + + def _format_dataframe(self, show_percentages: bool) -> pd.DataFrame: + """Format phase breakdown as DataFrame.""" + if self._df.empty: + return pd.DataFrame() + + display_rows = [] + for _, row in self._df.iterrows(): + duration_ms = row["duration"] / 1000.0 + display_row = { + "Name": row["name"], + "Parent": row["parent"] if row["parent"] else "(root)", + "Depth": row["depth"], + "Duration (ms)": duration_ms, + } + if show_percentages and self._total_time > 0: + pct = row["duration"] / self._total_time * 100 + display_row["% of Total"] = pct + display_rows.append(display_row) + + display_df = pd.DataFrame(display_rows) + + if show_percentages: + display_df["% of Total"] = display_df["% of Total"].round(1) + + return display_df + + def _build_plotly_data(self) -> dict: + """Convert to plotly hierarchical visualization format.""" + return { + "data_frame": self._df, + "names": "name", + "parents": "parent", + "values": "duration", + "branchvalues": "total", + } + + +# Hierarchical phase specification +# There are over 100 totals in the JSON file, but a lot of them overlap. +# If the children total more than their parent, we will throw a ValueError. +# +# The hierarchy is specified as a nested dictionary where: +# - Keys are phase names (matching "Total " events in the trace) +# - Values are dictionaries of child phases (or empty dict {} for leaf nodes) +# - Empty string "" as a key means "calculate Other as residual" +# +# This structure supports arbitrary nesting depth. +PHASE_HIERARCHY = { + "ExecuteCompiler": { + "Frontend": { + "InstantiateFunction": {}, + }, + "Backend": { + "Optimizer": {}, + "CodeGenPasses": {}, + }, + } +} + + +def get_phase_breakdown(df: pd.DataFrame) -> PhaseBreakdown: + """ + Get hierarchical breakdown of compilation phases. + + Returns a PhaseBreakdown object with rich display and analysis methods, + using the pre-aggregated "Total" events from Clang's -ftime-trace output + for accurate statistics. + + All durations are in microseconds. + + The hierarchy is defined by the PHASE_HIERARCHY constant and supports + arbitrary nesting depth. The tree is traversed recursively to build + the phase breakdown. + + Args: + df: DataFrame from parse_file() + + Returns: + PhaseBreakdown object with rich display and analysis methods + + Raises: + ValueError: If required Total events are missing or if calculated + "Other" values are negative (indicating data inconsistency) + + Examples: + >>> df = parse_file('trace.json') + >>> breakdown = get_phase_breakdown(df) + >>> + >>> # Display in Jupyter (automatic) + >>> breakdown + >>> + >>> # Get total compilation time + >>> print(f"Total: {breakdown.total_ms:.1f}ms") + >>> + >>> # Access specific phases + >>> breakdown['InstantiateFunction'] + >>> breakdown.frontend + >>> breakdown.backend.top(5) + >>> + >>> # Visualize + >>> import plotly.express as px + >>> data = breakdown.to_plotly() + >>> fig = px.sunburst(**data) + >>> fig.show() + """ + if "name" not in df.columns or "dur" not in df.columns: + raise ValueError("DataFrame missing required 'name' or 'dur' columns") + + # Pre-filter to Total events for efficient lookup + total_events = df[df["name"].str.startswith("Total ", na=False)].copy() + total_events["phase"] = total_events["name"].str.removeprefix("Total ") + + def get_duration(phase_name: str) -> Optional[int]: + """Get duration in microseconds from a Total event.""" + matches = total_events[total_events["phase"] == phase_name] + if matches.empty: + return None + return int(matches.iloc[0]["dur"]) + + def process_node( + node_name: str, + parent_name: str, + depth: int, + children_spec: dict, + ) -> list[dict]: + """ + Recursively process a node and its children in the phase hierarchy. + + Args: + node_name: Name of the current phase node + parent_name: Name of the parent phase (empty string for root) + depth: Current depth in the tree (0 for root) + children_spec: Dictionary of child phases to process + + Returns: + List of row dictionaries for this node and all descendants + + Raises: + ValueError: If phase not found or children exceed parent duration + """ + # Get duration for this node + node_duration = get_duration(node_name) + if node_duration is None: + raise ValueError(f"No Total {node_name} event found in trace") + + # Add current node + rows = [ + { + "name": node_name, + "parent": parent_name, + "depth": depth, + "duration": node_duration, + } + ] + + if not children_spec: + return rows + + # Process all children recursively + children_total = 0 + for child_name, grandchildren_spec in children_spec.items(): + if child_name == "": + # Empty string means "Other" - skip for now, calculate as residual + continue + + # Recursively process this child and its descendants + child_rows = process_node( + child_name, node_name, depth + 1, grandchildren_spec + ) + rows.extend(child_rows) + + # Track total duration of direct children only (not grandchildren) + children_total += child_rows[0]["duration"] + + # Calculate and add "Other" if there's unaccounted time + other_duration = node_duration - children_total + if other_duration < 0: + raise ValueError( + f"{node_name} children total ({children_total}) " + f"exceeds parent total ({node_duration})" + ) + + if other_duration > 0: + rows.append( + { + "name": f"{node_name}_Other", + "parent": node_name, + "depth": depth + 1, + "duration": other_duration, + } + ) + + return rows + + # Start recursive traversal from root + root_name = "ExecuteCompiler" + if root_name not in PHASE_HIERARCHY: + raise ValueError(f"Root phase '{root_name}' not found in PHASE_HIERARCHY") + + all_rows = process_node( + root_name, + "", # Root has no parent + 0, # Root is at depth 0 + PHASE_HIERARCHY[root_name], + ) + + breakdown_df = pd.DataFrame(all_rows) + return PhaseBreakdown(breakdown_df) diff --git a/script/analyze_build/trace_analysis/template_analysis.py b/script/analyze_build/trace_analysis/template_analysis.py new file mode 100644 index 0000000000..ef483f6f53 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_analysis.py @@ -0,0 +1,80 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template instantiation analysis for Clang -ftime-trace data. + +This module provides specialized functions for analyzing C++ template +instantiation costs from Clang's -ftime-trace output. +""" + +import pandas as pd +from .template_parser import parse_template_detail + + +def get_template_instantiation_events(df: pd.DataFrame) -> pd.DataFrame: + """ + Filter to template instantiation events and parse arg_detail into structured columns. + + Returns events for: + - InstantiateFunction: Function template instantiations + - InstantiateClass: Class template instantiations + + The returned DataFrame includes parsed columns from arg_detail: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Args: + df: DataFrame from parse_file() + + Returns: + Filtered DataFrame containing template instantiation events with parsed columns + + Example: + >>> df = parse_file('trace.json') + >>> templates = get_template_instantiation_events(df) + >>> templates.sort_values('dur', ascending=False).head(10) + >>> # Filter to CK types only + >>> ck_templates = templates[templates['is_ck_type']] + >>> # Group by template name + >>> templates.groupby('template_name')['dur'].sum() + """ + # Filter to template instantiation events + filtered_df = ( + df[ + df["name"].isin( + [ + "InstantiateClass", + "InstantiateFunction", + ] + ) + ] + .drop( + columns=[ + "arg_avg ms", + "arg_count", + "arg_name", + "cat", + "id", + "ph", + "pid", + "tid", + ] + ) + .reset_index(drop=True) + ) + + # Parse arg_detail into structured columns + parsed_data = filtered_df["arg_detail"].apply(parse_template_detail) + + # Convert list of dicts to DataFrame and join with original + parsed_df = pd.DataFrame(parsed_data.tolist()) + + # Combine with original data + result_df = pd.concat([filtered_df, parsed_df], axis=1) + + return result_df diff --git a/script/analyze_build/trace_analysis/template_parser.py b/script/analyze_build/trace_analysis/template_parser.py new file mode 100644 index 0000000000..2551465bd4 --- /dev/null +++ b/script/analyze_build/trace_analysis/template_parser.py @@ -0,0 +1,301 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Template detail string parser for C++ template instantiations. + +This module provides functions to parse the arg_detail strings from +Clang's -ftime-trace output into structured components. +""" + +import re +from typing import Dict + + +def parse_template_detail(detail_str: str) -> Dict[str, any]: + """ + Parse a template detail string into structured components. + + Args: + detail_str: The arg_detail string from -ftime-trace + + Returns: + Dictionary with parsed fields: + - namespace: Top-level namespace (e.g., 'std', 'ck') + - template_name: Template name without parameters + - full_qualified_name: Full namespace::template_name + - param_count: Number of template parameters + - is_ck_type: Boolean indicating if this is a CK library type + - is_nested: Boolean indicating if contains nested templates + + Example: + >>> parse_template_detail('std::basic_string') + { + 'namespace': 'std', + 'template_name': 'basic_string', + 'full_qualified_name': 'std::basic_string', + 'param_count': 1, + 'is_ck_type': False, + 'is_nested': False + } + """ + # Handle empty or invalid strings + if not detail_str or not isinstance(detail_str, str): + return _empty_result() + + # Remove surrounding quotes if present + detail_str = detail_str.strip('"') + + # Extract components + namespace = extract_namespace(detail_str) + template_name = extract_template_name(detail_str) + full_qualified_name = extract_full_qualified_name(detail_str) + param_count = count_template_params(detail_str) + is_ck = is_ck_template(detail_str) + is_nested = is_nested_template(detail_str) + + return { + "namespace": namespace, + "template_name": template_name, + "full_qualified_name": full_qualified_name, + "param_count": param_count, + "is_ck_type": is_ck, + "is_nested": is_nested, + } + + +def extract_namespace(detail_str: str) -> str: + """ + Extract the top-level namespace from a template detail string. + + Args: + detail_str: The template detail string + + Returns: + The top-level namespace, or empty string if none found + + Example: + >>> extract_namespace('std::basic_string') + 'std' + >>> extract_namespace('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find first :: separator + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)::", detail_str) + if match: + return match.group(1) + + # No namespace found - check if it's a simple type + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_template_name(detail_str: str) -> str: + """ + Extract the template name without namespace or parameters. + + Args: + detail_str: The template detail string + + Returns: + The template name without namespace or parameters + + Example: + >>> extract_template_name('std::basic_string') + 'basic_string' + >>> extract_template_name('ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<...>') + 'GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the last component before < or end of string + # This handles nested namespaces like ck::tensor_operation::device::DeviceConv2d + match = re.search(r"::([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + # No :: found, try to get name before < + match = re.match(r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def extract_full_qualified_name(detail_str: str) -> str: + """ + Extract the full qualified name (namespace::...::template_name). + + Args: + detail_str: The template detail string + + Returns: + The full qualified name without template parameters + + Example: + >>> extract_full_qualified_name('std::basic_string') + 'std::basic_string' + >>> extract_full_qualified_name('ck::tensor_operation::device::DeviceConv2d<...>') + 'ck::tensor_operation::device::DeviceConv2d' + """ + if not detail_str: + return "" + + # Remove quotes + detail_str = detail_str.strip('"') + + # Match everything up to the first < or end of string + match = re.match(r"^([a-zA-Z_:][a-zA-Z0-9_:]*)\s*(?:<|$)", detail_str) + if match: + return match.group(1) + + return "" + + +def count_template_params(detail_str: str) -> int: + """ + Count the number of top-level template parameters. + + This counts commas at the top level of template brackets, + not commas inside nested templates. + + Args: + detail_str: The template detail string + + Returns: + Number of template parameters, or 0 if not a template + + Example: + >>> count_template_params('std::basic_string') + 1 + >>> count_template_params('std::tuple') + 3 + """ + if not detail_str or "<" not in detail_str: + return 0 + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return 0 + + # Track bracket depth to only count top-level commas + depth = 0 + param_count = 1 # Start with 1 (if there's a <, there's at least one param) + in_template = False + + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + in_template = True + elif char == ">": + depth -= 1 + if depth == 0: + # We've closed the outermost template + break + elif char == "," and depth == 1: + # Top-level comma + param_count += 1 + + return param_count if in_template else 0 + + +def is_ck_template(detail_str: str) -> bool: + """ + Check if this is a CK library template. + + Args: + detail_str: The template detail string + + Returns: + True if this is a CK library type, False otherwise + + Example: + >>> is_ck_template('ck::tensor_operation::device::DeviceConv2d<...>') + True + >>> is_ck_template('std::basic_string') + False + """ + if not detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Check if it starts with ck:: or contains ::ck:: + return detail_str.startswith("ck::") or "::ck::" in detail_str + + +def is_nested_template(detail_str: str) -> bool: + """ + Check if this template contains nested template instantiations. + + Args: + detail_str: The template detail string + + Returns: + True if contains nested templates, False otherwise + + Example: + >>> is_nested_template('std::vector') + False + >>> is_nested_template('std::vector') + True + """ + if not detail_str or "<" not in detail_str: + return False + + # Remove quotes + detail_str = detail_str.strip('"') + + # Find the template parameter section + start = detail_str.find("<") + if start == -1: + return False + + # Look for nested < after the first one + depth = 0 + for i in range(start, len(detail_str)): + char = detail_str[i] + + if char == "<": + depth += 1 + if depth > 1: + # Found a nested template + return True + elif char == ">": + depth -= 1 + if depth == 0: + break + + return False + + +def _empty_result() -> Dict[str, any]: + """Return an empty result dictionary with default values.""" + return { + "namespace": "", + "template_name": "", + "full_qualified_name": "", + "param_count": 0, + "is_ck_type": False, + "is_nested": False, + } From 42a731b791e72d4ea5f270be905e6fa1eb524626 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 12:28:59 -0500 Subject: [PATCH 12/28] Updating failure patterns to be more reliable and adding tests to verify they are caught in the logs --- Jenkinsfile | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index f3a597e404..712602e532 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /docker login failed/, description: "Docker login failed"], + [pattern: /.docker login failed./, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /GPU not found/, description: "GPU not found"], + [pattern: /.GPU not found./, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] @@ -1290,6 +1290,13 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" + // Todo: Remove test examples + echo "GPU not found" + echo "Testing GPU not found" + echo "GPU not found Testing" + echo "docker login failed" + echo "Testing docker login failed" + echo "docker login failed Testing" } } } From 786965b95ed049e7ba2f0e6f00875a2634db90f9 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 12:47:27 -0500 Subject: [PATCH 13/28] Fixing Jenkinsfile too large error --- Jenkinsfile | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 712602e532..cd7678df1a 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -34,6 +34,18 @@ def checkForPattern(pattern, log) { return [found: false, matchedLine: "", context: ""] } +def testLog() { + // Todo: Remove test examples + sh """ + echo "GPU not found" + echo "Testing GPU not found" + echo "GPU not found Testing" + echo "docker login failed" + echo "Testing docker login failed" + echo "docker login failed Testing" + """ +} + // Scan the build logs for failures and send notifications. def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. @@ -1290,13 +1302,7 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" - // Todo: Remove test examples - echo "GPU not found" - echo "Testing GPU not found" - echo "GPU not found Testing" - echo "docker login failed" - echo "Testing docker login failed" - echo "docker login failed Testing" + testLog() } } } From 95768d1b22697488f793ab90fbc7ca8e241aa6e7 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 13:02:25 -0500 Subject: [PATCH 14/28] Adding forcing failure to test notifications --- Jenkinsfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Jenkinsfile b/Jenkinsfile index cd7678df1a..5e1a5af3e4 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -44,6 +44,7 @@ def testLog() { echo "Testing docker login failed" echo "docker login failed Testing" """ + error("Forcing failure to test notifications") } // Scan the build logs for failures and send notifications. From 58e1d032441fed82d33240f132168ad94bcba476 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 13:56:47 -0500 Subject: [PATCH 15/28] Removing working cases to test other failure examples --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5e1a5af3e4..1c50698d3c 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,10 +39,8 @@ def testLog() { sh """ echo "GPU not found" echo "Testing GPU not found" - echo "GPU not found Testing" echo "docker login failed" echo "Testing docker login failed" - echo "docker login failed Testing" """ error("Forcing failure to test notifications") } From 6c596b95535fffcacc2d4fadb8199ab5d00d7853 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:21:06 -0500 Subject: [PATCH 16/28] Testing a pattern to support all text variations --- Jenkinsfile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 1c50698d3c..5ae56929dd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -39,8 +39,10 @@ def testLog() { sh """ echo "GPU not found" echo "Testing GPU not found" + echo "GPU not found Testing" echo "docker login failed" echo "Testing docker login failed" + echo "docker login failed Testing" """ error("Forcing failure to test notifications") } @@ -50,10 +52,10 @@ def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. def failurePatterns = [ [pattern: /login attempt to .* failed with status: 401 Unauthorized/, description: "Docker registry authentication failed"], - [pattern: /.docker login failed./, description: "Docker login failed"], + [pattern: /(.*)docker login failed(.*)/, description: "Docker login failed"], [pattern: /HTTP request sent .* 404 Not Found/, description: "HTTP request failed with 404"], [pattern: /cat: .* No such file or directory/, description: "GPU not found"], - [pattern: /.GPU not found./, description: "GPU not found"], + [pattern: /(.*)GPU not found(.*)/, description: "GPU not found"], [pattern: /Could not connect to Redis at .* Connection timed out/, description: "Redis connection timed out"] ] From 1397924c21603123c14d0db3242532eff666eae2 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:25:21 -0500 Subject: [PATCH 17/28] Removed working tests. Validating remaining tests. --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 5ae56929dd..d860dc0fca 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -37,10 +37,8 @@ def checkForPattern(pattern, log) { def testLog() { // Todo: Remove test examples sh """ - echo "GPU not found" echo "Testing GPU not found" echo "GPU not found Testing" - echo "docker login failed" echo "Testing docker login failed" echo "docker login failed Testing" """ From 402f21d0a6ccf22c64f252f84768e046690b8810 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:27:18 -0500 Subject: [PATCH 18/28] Removed working tests. Validating remaining tests. --- Jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index d860dc0fca..49949d8851 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -37,9 +37,7 @@ def checkForPattern(pattern, log) { def testLog() { // Todo: Remove test examples sh """ - echo "Testing GPU not found" echo "GPU not found Testing" - echo "Testing docker login failed" echo "docker login failed Testing" """ error("Forcing failure to test notifications") From 8654c0628f83261d3dd64cfb4ec80e9dd2b29fa5 Mon Sep 17 00:00:00 2001 From: Andrew Clark Date: Fri, 23 Jan 2026 14:29:13 -0500 Subject: [PATCH 19/28] Finished testing failure types. Removed testing code. --- Jenkinsfile | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/Jenkinsfile b/Jenkinsfile index 49949d8851..1a8be258bd 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -34,15 +34,6 @@ def checkForPattern(pattern, log) { return [found: false, matchedLine: "", context: ""] } -def testLog() { - // Todo: Remove test examples - sh """ - echo "GPU not found Testing" - echo "docker login failed Testing" - """ - error("Forcing failure to test notifications") -} - // Scan the build logs for failures and send notifications. def sendFailureNotifications() { // Error patterns to scan build logs for specific failure types and send detailed notifications. @@ -1299,7 +1290,6 @@ pipeline { script { env.SHOULD_RUN_CI = String.valueOf(params.FORCE_CI.toBoolean() || shouldRunCICheck()) echo "SHOULD_RUN_CI: ${env.SHOULD_RUN_CI}" - testLog() } } } From cc75948d1c7f732d102c8e31dc007a2ccd07761f Mon Sep 17 00:00:00 2001 From: Robin Voetter Date: Mon, 26 Jan 2026 23:50:15 +0100 Subject: [PATCH 20/28] [CK_BUILDER] conv bwd weight testing (#3618) * ck-builder: restructure testing conv In order to prepare for bwd of conv testing, this commit moves some files and types around so that we can reuse ckt::Args for both forward and backwards convolution. * ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp This will allow us to more easily include fwd.hpp from backwards definitions, which is required for initializing bwd values. * ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3 Turns out that the supplied layout isn't actually supported... * ck-builder: ck and reference conv integration for bwd weight * ck-builder: ck bwd weight execution test * ck-builder: ckt::run support for ck-tile bwd weight * ck-builder: ck tile bwd weight execution test * ck-builder: extra debug printing in MatchesReference * ck-builder: make ckt::run return RunResult This type is more convenient than std::tuple, as it will allow us to use google test matchers with this in the future. * ck-builder: RunResult matcher Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error message about how and why running an algorithm failed. * ck-builder: doc fixes * ck-builder: add missing headers --- .../testing/{conv_fwd.hpp => conv/args.hpp} | 64 +--- .../builder/testing/conv/bwd_weight.hpp | 71 +++++ .../builder/testing/conv/bwd_weight_ck.hpp | 276 ++++++++++++++++++ .../ck_tile.hpp} | 92 ++++-- .../ck_tile/builder/testing/conv/fwd.hpp | 69 +++++ .../{conv_fwd_ck.hpp => conv/fwd_ck.hpp} | 58 ++-- .../builder/testing/conv/reference.hpp | 137 +++++++++ .../builder/testing/conv_fwd_reference.hpp | 88 ------ .../builder/testing/tensor_initialization.hpp | 1 + .../ck_tile/builder/testing/testing.hpp | 62 +++- .../builder/testing/testing_reflect.hpp | 2 + experimental/builder/test/CMakeLists.txt | 2 +- ...st_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp | 59 +++- .../conv/ck/test_ckb_conv_fwd_2d_fp16.cpp | 13 +- .../test_ckb_conv_bwd_weight_2d_fp16_v3.cpp | 94 ++++-- .../conv/ck_tile/test_ckb_conv_fwd_e2e.cpp | 13 +- .../builder/test/test_testing_utils.cpp | 17 ++ experimental/builder/test/testing_utils.cpp | 18 ++ experimental/builder/test/testing_utils.hpp | 32 ++ .../builder/test/unit_conv_fwd_testing.cpp | 2 +- experimental/builder/test/unit_validation.cpp | 5 +- .../instances/instance_includes.inc | 3 +- .../instances/instance_run.inc | 8 +- .../grouped_convolution_forward_tile_algs.hpp | 9 +- .../grouped_convolution_signatures.hpp | 2 +- .../src/profile_grouped_conv_fwd_tile.cpp | 2 +- .../test_grouped_convnd_fwd_tile.cpp | 2 +- 27 files changed, 939 insertions(+), 262 deletions(-) rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd.hpp => conv/args.hpp} (82%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd_ck_tile.hpp => conv/ck_tile.hpp} (52%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp rename experimental/builder/include/ck_tile/builder/testing/{conv_fwd_ck.hpp => conv/fwd_ck.hpp} (73%) create mode 100644 experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp delete mode 100644 experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp similarity index 82% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/args.hpp index 51edf41cba..eba6771964 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/args.hpp @@ -7,26 +7,25 @@ #include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/builder/testing/testing.hpp" -#include "ck_tile/builder/testing/testing_reflect.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/tensor_buffer.hpp" -#include "ck_tile/host/convolution_parameter.hpp" -#include "ck_tile/builder/testing/tensor_initialization.hpp" #include "ck_tile/builder/testing/tensor_descriptor.hpp" -#include "ck_tile/builder/testing/validation.hpp" +#include "ck_tile/host/convolution_parameter.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" /// This file implements common functionality for invoking/testing grouped /// forward convolutions created through the CK Builder API. The main item -/// of it is the ConvArgs structure - which contains a complete description +/// of it is the Args structure - which contains a complete description /// of a convolution operation. /// /// It is not intended that this file contains implementation details for /// actually launching a convolution operation. As this can be done /// through different APIs depending on the kernel (CK, CK Tile, or a /// reference implementation), the code dealing with that is split out -/// into a separate header for each implementation. +/// into a separate header for each implementation. Nor does this file +/// deal with details for defining the data types (`Inputs` and `Outputs`) +/// for different conv directions, that is also split out into separate +/// headers to keep this one small. namespace ck_tile::builder::test { @@ -56,7 +55,7 @@ struct ConvTensorLengths /// /// @see Args template - requires ValidConvSignature && ConvDirectionIsForward + requires ValidConvSignature struct Args { constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim; @@ -204,53 +203,4 @@ struct Args } }; -/// @brief `Inputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Inputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Inputs -{ - void* input; - void* weight; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("input", args.make_input_descriptor(), &Inputs::input); - inspect("weight", args.make_weight_descriptor(), &Inputs::weight); - } -}; - -/// @brief `Outputs` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see Outputs -template - requires ValidConvSignature && ConvDirectionIsForward -struct Outputs -{ - void* output; - - static void reflect(const Args& args, const auto& inspect) - { - inspect("output", args.make_output_descriptor(), &Outputs::output); - } -}; - -/// @brief `init_inputs()` specialization for forward convolution. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// -/// @see alloc_inputs() -template - requires ValidConvSignature && ConvDirectionIsForward -void init_inputs(const Args& args, Inputs inputs) -{ - init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); - init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); -} - } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp new file mode 100644 index 0000000000..ce5811c87a --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight.hpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/error.hpp" + +/// This file deals with the backward weight-specific details of running grouped +/// convolution backwards weight operations. It mainly defines the data +/// structures (`Input` and `Output`), initialization, and validation. Note +/// that for this operation specifically, many of the operations are +/// implemented automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Inputs +{ + void* input; + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("output", args.make_output_descriptor(), &Inputs::output); + } +}; + +/// @brief `Outputs` specialization for backwards weight convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +struct Outputs +{ + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("weight", args.make_weight_descriptor(), &Outputs::weight); + } +}; + +/// @brief `init_inputs()` specialization for backwards convolution. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsBackwardWeight +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp new file mode 100644 index 0000000000..0b1ffeb707 --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/bwd_weight_ck.hpp @@ -0,0 +1,276 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// bwd grouped convolution operations in old CK. The main item is the +/// `run()` function, which is the main implementation used to invoke +/// CK grouped forward convolution kernels. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except +/// with some utility aliases. For that reason, its moved to this detail +/// namespace. +template , + typename Ops = factory::internal::ConvElementwiseOps> +concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv, + const Types::InDataType* p_a, + Types::WeiDataType* p_b, + const Types::OutDataType* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde, + ck::index_t split_k) { + requires ValidConvSignature; + requires ConvDirectionIsBackwardWeight; + + { + conv.MakeArgument(p_a, + p_b, + p_e, + // TODO: Actually support multiple d + {}, + // A lengths/strides + lengths, + strides, + // B lengths/strides + lengths, + strides, + // E lengths/strides + lengths, + strides, + // TODO: Multiple D lengths/strides + {}, + {}, + // strides/dilations/pads + filter, + filter, + filter, + filter, + // element-wise operations. + elementwise_a, + elementwise_b, + elementwise_cde, + split_k) + }; +}; + +} // namespace detail + +/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance; + +/// @brief Concept for checking whether a bwd weight convolution is multiple-d and +/// invoked like old CK. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +template +concept CkConvBwdWeightMultipleDInstance = + detail::CkConvBwdWeightMultipleDInstance; + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +/// @brief `run()` specialization for backward weight convolution and old CK. +/// +/// This overload is specialized for Multiple-D. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + using Types = factory::internal::ConvTensorDataTypes; + + constexpr auto spatial_dim = SIGNATURE.spatial_dim; + + const auto copy = [](const auto& src, auto& dst) { + std::copy(src.begin(), src.end(), dst.begin()); + }; + + const auto to_ck_lengths = [&](const auto& src) { + std::array result; + copy(src, result); + return result; + }; + + const auto to_ck_extent = [&](const auto& extent) { + std::array result; + copy(extent, result); + return result; + }; + + const auto param = args.to_ck_conv_param(); + + const auto input_desc = args.make_input_descriptor(); + const auto weight_desc = args.make_weight_descriptor(); + const auto output_desc = args.make_output_descriptor(); + + auto ck_args = conv.MakeArgument(static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + {}, // TODO + to_ck_lengths(input_desc.get_lengths()), + to_ck_lengths(input_desc.get_strides()), + to_ck_lengths(weight_desc.get_lengths()), + to_ck_lengths(weight_desc.get_strides()), + to_ck_lengths(output_desc.get_lengths()), + to_ck_lengths(output_desc.get_strides()), + {}, // TODO + {}, // TODO + to_ck_extent(param.conv_filter_strides_), + to_ck_extent(param.conv_filter_dilations_), + to_ck_extent(param.input_left_pads_), + to_ck_extent(param.input_right_pads_), + args.a_elementwise_op, + args.b_elementwise_op, + args.cde_elementwise_op, + args.k_batch); + + if(!conv.IsSupportedArgument(ck_args)) + return RunResult::not_supported("invalid ck arguments"); + + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {})); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp similarity index 52% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index a8f6825524..133d7d69b7 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -3,9 +3,8 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" #include @@ -28,9 +27,39 @@ namespace detail { /// namespace. template concept CkTileConvInstance = requires(Conv&) { + requires ValidConvSignature; { Conv::BlockSize() }; }; +template +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + const ck_tile::stream_config s_conf) +{ + using Conv = std::remove_reference_t; + const auto param = args.to_ck_tile_conv_param(); + + ck_tile::GroupedConvHostArgs + host_args(param, input, weight, {}, output, args.k_batch); + + auto kargs = Conv::MakeKernelArgs(host_args); + + const dim3 grids = Conv::GridSize(kargs); + const dim3 blocks = Conv::BlockSize(); + + if(!Conv::IsSupportedArgument(kargs)) + return RunResult::not_supported("unsupported ck_tile arguments"); + + constexpr index_t minimum_occupancy = + Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; + + return RunResult::from_runtime(ck_tile::launch_kernel( + s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +} + } // namespace detail /// @brief Concept for checking whether a convolution is invoked like CK Tile. @@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance; /// @brief `run()` specialization for forward convolution and CK Tile. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template - requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkTileConvInstance auto& conv, + requires ConvDirectionIsForward +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, const ck_tile::stream_config s_conf = {}) { - using Conv = std::remove_reference_t; - const auto param = args.to_ck_tile_conv_param(); + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(inputs.weight), + static_cast(outputs.output), + s_conf); +} - ck_tile::GroupedConvFwdHostArgs<> host_args( - param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch); - - auto kargs = Conv::MakeKernelArgs(host_args); - - const dim3 grids = Conv::GridSize(kargs); - const dim3 blocks = Conv::BlockSize(); - - if(!Conv::IsSupportedArgument(kargs)) - { - std::cout << "Not supported!"; - return std::make_tuple(false, 0.f); - } - - constexpr index_t minimum_occupancy = - Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; - - return std::make_tuple( - true, - ck_tile::launch_kernel( - s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); +/// @brief `run()` specialization for backwards weight convolution and CK Tile. +/// +/// @tparam SIGNATURE Backwards weight convolution signature. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ConvDirectionIsBackwardWeight +[[nodiscard]] RunResult run(CkTileConvInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs, + const ck_tile::stream_config s_conf = {}) +{ + return detail::run(conv, + args, + static_cast(inputs.input), + static_cast(outputs.weight), + static_cast(inputs.output), + s_conf); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp new file mode 100644 index 0000000000..b81892c91e --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd.hpp @@ -0,0 +1,69 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/tensor_initialization.hpp" +#include "ck_tile/builder/testing/testing_reflect.hpp" +#include "ck_tile/builder/testing/conv/args.hpp" + +/// This file deals with the forward-specific details of running grouped +/// convolution forward operations. It mainly defines the data structures +/// (`Input` and `Output`), initialization, and validation. Note that +/// for this operation specifically, many of the operations are implemented +/// automatically via testing_reflect.hpp. + +namespace ck_tile::builder::test { + +/// @brief `Inputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Inputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Inputs +{ + void* input; + void* weight; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("input", args.make_input_descriptor(), &Inputs::input); + inspect("weight", args.make_weight_descriptor(), &Inputs::weight); + } +}; + +/// @brief `Outputs` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see Outputs +template + requires ValidConvSignature && ConvDirectionIsForward +struct Outputs +{ + void* output; + + // See testing_reflect.hpp + static void reflect(const Args& args, const auto& inspect) + { + inspect("output", args.make_output_descriptor(), &Outputs::output); + } +}; + +/// @brief `init_inputs()` specialization for forward convolution. +/// +/// @tparam SIGNATURE Forward convolution signature. +/// +/// @see init_inputs() +template + requires ValidConvSignature && ConvDirectionIsForward +void init_inputs(const Args& args, Inputs inputs) +{ + init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f); + init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp similarity index 73% rename from experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp rename to experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp index f911dca21c..5eca79508c 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_ck.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/fwd_ck.hpp @@ -3,14 +3,14 @@ #pragma once -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/builder/testing/testing.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp" +#include "ck_tile/host/kernel_launch.hpp" #include #include /// This file contains the implementation details for invoking/testing -/// grouped convolution operations in old CK. The main item is the +/// fwd grouped convolution operations in old CK. The main item is the /// `run()` function, which is the main implementation used to invoke /// CK grouped forward convolution kernels. @@ -18,10 +18,9 @@ namespace ck_tile::builder::test { namespace detail { -/// @brief Concept for checking whether this is the reference convolution -/// implementation. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// -/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except +/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except /// with some utility aliases. For that reason, its moved to this detail /// namespace. template > -concept CkConvInstance = requires(Conv& conv, - // TODO: This should be changed depending on IsMultiA etc. - // Currently that is not yet supported elsewhere anyway. - const void* p_a, - const void* p_b, - void* p_e, - std::array lengths, - std::array strides, - std::array filter, - Ops::InElementwiseOp elementwise_a, - Ops::WeiElementwiseOp elementwise_b, - Ops::OutElementwiseOp elementwise_cde) { +concept CkConvFwdInstance = requires(Conv& conv, + // TODO: This should be changed depending on IsMultiA etc. + // Currently that is not yet supported elsewhere anyway. + const void* p_a, + const void* p_b, + void* p_e, + std::array lengths, + std::array strides, + std::array filter, + Ops::InElementwiseOp elementwise_a, + Ops::WeiElementwiseOp elementwise_b, + Ops::OutElementwiseOp elementwise_cde) { + requires ValidConvSignature; + requires ConvDirectionIsForward; + { conv.MakeArgument(p_a, p_b, @@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv, } // namespace detail -/// @brief Concept for checking whether a convolution is invoked like old CK. +/// @brief Concept for checking whether a fwd convolution is invoked like old CK. /// /// This concept is used to tell whether a convolution implementation is /// likely to be an "old CK" implementation - that is, whether we should @@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv, /// - SIGNATURE is the operation signature. /// - Conv is a convolution instance created by the CK Builder API. template -concept CkConvInstance = detail::CkConvInstance; +concept CkConvFwdInstance = detail::CkConvFwdInstance; /// @brief `run()` specialization for forward convolution and old CK. /// /// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @see run() template requires ValidConvSignature && ConvDirectionIsForward -std::tuple run(CkConvInstance auto& conv, +[[nodiscard]] RunResult run(CkConvFwdInstance auto& conv, const Args& args, const Inputs& inputs, const Outputs& outputs, @@ -126,6 +125,9 @@ std::tuple run(CkConvInstance auto& conv, const auto weight_desc = args.make_weight_descriptor(); const auto output_desc = args.make_output_descriptor(); + if(args.k_batch != 1) + return RunResult::not_supported("ck fwd does not support k_batch != 1"); + auto ck_args = conv.MakeArgument(inputs.input, inputs.weight, {}, @@ -147,11 +149,9 @@ std::tuple run(CkConvInstance auto& conv, args.cde_elementwise_op); if(!conv.IsSupportedArgument(ck_args)) - { - std::cout << "invalid argument" << std::endl; - } + return RunResult::not_supported("unsupported ck arguments"); - return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf)); + return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf)); } } // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp new file mode 100644 index 0000000000..169d0741ff --- /dev/null +++ b/experimental/builder/include/ck_tile/builder/testing/conv/reference.hpp @@ -0,0 +1,137 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/builder/testing/testing.hpp" +#include "ck/library/utility/convolution_parameter.hpp" +#include +#include + +/// This file contains the implementation details for invoking/testing +/// grouped convolution operations using the reference implementation. +/// The main item is the `run()` function, which is the primary way to +/// invoke the reference execution mechanism. +/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, +/// but its made specific to the reference implementation, which is +/// invoked in a slightly different way. + +namespace ck_tile::builder::test { + +namespace detail { + +/// @brief Concept for checking whether this is the reference convolution +/// implementation. +/// +/// This concept is used to tell whether a convolution implementation is +/// likely to be the reference implementation - that is, whether we should +/// invoke it like the reference kernel. This is mainly used with `run()` to +/// differentiate which implementation that should be invoked. +/// +/// - SIGNATURE is the operation signature. +/// - Conv is a convolution instance created by the CK Builder API. +/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors. +template +concept RefConvInstance = requires(Conv& conv, + InDataType* input, + WeiDataType* weight, + OutDataType* output, + ck::utils::conv::ConvParam param) { + requires ValidConvSignature; + { conv.Run(input, weight, output, param) }; +}; + +/// @brief Generic `run` implementation for forward/backwards reference kernels. +/// +/// @tparam SIGNATURE The signature of the operation to perform. +/// +/// @return std::tuple - whether the problem is supported and +/// kernel execution time (0.0f for reference). +/// @see run() +template +[[nodiscard]] RunResult +run(RefConvInstance auto& conv, + const Args& args, + InDataType* input, + WeiDataType* weight, + OutDataType* output) +{ + // We don't want to compute the output dims manually, just get + // them via the existing infrastructure + const auto param = args.to_ck_conv_param(); + + // TODO: The reference convolution is currently missing a few features. + // Just throw for now, but regard these as TODO items that should be resolved + // eventually. + + if(!args.make_input_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv"); + + if(!args.make_weight_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv"); + + if(!args.make_output_descriptor().is_packed()) + return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv"); + + conv.Run(input, weight, output, param); + return RunResult::from_runtime(0); // ref conv does not return a meaningful runtime. +} + +} // namespace detail + +/// @brief Concept for checking whether this is the reference convolution +/// forward implementation. +template +concept RefConvFwdInstance = + detail::RefConvInstance && + ConvDirectionIsForward; + +/// @brief `run()` specialization for forward convolution and the reference +/// forward implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template + requires ValidConvSignature && + // TODO: Maybe we can unify this implementation for bwd/weight too? + // for now, just concern outselves with reference and see when the + // rest of the bwd/weight plumbing is there. + ConvDirectionIsForward +[[nodiscard]] RunResult run(RefConvFwdInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, inputs.weight, outputs.output); +} + +/// @brief Concept for checking whether this is the reference convolution +/// backward weight implementation. +template +concept RefConvBwdWeightInstance = + detail::RefConvInstance && + ConvDirectionIsBackwardWeight; + +/// @brief `run()` specialization for forward convolution and the reference +/// backward weight implementation. +/// +/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight. +/// @returns RunResult about how the operation completed (or not). +/// +/// @see run() +template +[[nodiscard]] RunResult run(RefConvBwdWeightInstance auto& conv, + const Args& args, + const Inputs& inputs, + const Outputs& outputs) +{ + return detail::run(conv, args, inputs.input, outputs.weight, inputs.output); +} + +} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp b/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp deleted file mode 100644 index ff276f7c9c..0000000000 --- a/experimental/builder/include/ck_tile/builder/testing/conv_fwd_reference.hpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/builder/testing/conv_fwd.hpp" -#include -#include - -/// This file contains the implementation details for invoking/testing -/// grouped convolution operations using the reference implementation. -/// The main item is the `run()` function, which is the primary way to -/// invoke the reference execution mechanism. -/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`, -/// but its made specific to the reference implementation, which is -/// invoked in a slightly different way. - -namespace ck_tile::builder::test { - -/// @brief Concept for checking whether this is the reference convolution -/// implementation. -/// -/// This concept is used to tell whether a convolution implementation is -/// likely to be the reference implementation - that is, whether we should -/// invoke it like the reference kernel. This is mainly used with `run()` to -/// differentiate which implementation that should be invoked. -/// -/// - SIGNATURE is the operation signature. -/// - Conv is a convolution instance created by the CK Builder API. -template -concept RefConvInstance = requires(Conv& conv, - const void* input, - const void* weight, - void* output, - ck::utils::conv::ConvParam param) { - { conv.Run(input, weight, output, param) }; -}; - -/// @brief `run()` specialization for forward convolution and the reference -/// implementation. -/// -/// @tparam SIGNATURE Forward convolution signature. -/// @throws std::runtime_error if the arguments weren't actually valid for the -/// operation. This should be caught and reported by the testing framework. -/// -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f for reference). -/// @see run() -template - requires ValidConvSignature && - // TODO: Maybe we can unify this implementation for bwd/weight too? - // for now, just concern outselves with reference and see when the - // rest of the bwd/weight plumbing is there. - ConvDirectionIsForward -std::tuple run(RefConvInstance auto& conv, - const Args& args, - const Inputs& inputs, - const Outputs& outputs) -{ - // We don't want to compute the output dims manually, just get - // them via the existing infrastructure - const auto param = args.to_ck_conv_param(); - - // TODO: The reference convolution is currently missing a few features. - // Just throw for now, but regard these as TODO items that should be resolved - // eventually. - - if(!args.make_input_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_weight_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - if(!args.make_output_descriptor().is_packed()) - { - std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl; - return std::make_tuple(false, 0.0f); - } - - conv.Run(inputs.input, inputs.weight, outputs.output, param); - return std::make_tuple(true, 0.0f); -} - -} // namespace ck_tile::builder::test diff --git a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp index 2976e6c14b..35fc1f4ee8 100644 --- a/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/tensor_initialization.hpp @@ -12,6 +12,7 @@ #include "ck_tile/builder/conv_signature_concepts.hpp" #include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp" #include "ck_tile/builder/testing/type_traits.hpp" +#include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/host/host_tensor.hpp" #include "ck/utility/data_type.hpp" diff --git a/experimental/builder/include/ck_tile/builder/testing/testing.hpp b/experimental/builder/include/ck_tile/builder/testing/testing.hpp index e61d7c4da5..307871b47a 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing.hpp @@ -3,7 +3,11 @@ #pragma once +#include #include +#include +#include +#include #include "ck_tile/builder/testing/tensor_descriptor.hpp" #include "ck_tile/builder/testing/tensor_buffer.hpp" @@ -288,6 +292,57 @@ ValidationReport validate(const Args& args, Outputs actual, Outputs expected) = delete; +/// @brief This structure represents the result of a run operation. +/// +/// The structure contains multiple fields with information about +/// how the operation completed (or not). See those for more info. +struct RunResult +{ + /// If this value is not set to `std::nullopt`, there was a problem + /// while running the algorithm. In this case, the outputs are not + /// valid (though may be partially or completely overwritten), and + /// the optional contains a short debug message that indicates the + /// problem. + std::optional error = std::nullopt; + + /// The runtime of the kernel in milliseconds, if measured. Whether the + /// runtime is measured at all depends on the stream configuration + /// passed to run(). 0 if not measured or if there was an error. This + /// value is averaged over the total amount of runs actually done. Again, + /// this is usually configured via the stream config. + float runtime = 0.f; + + /// @brief Utility function for constructing a RunResult from an unsupported operation. + /// + /// @param msg A short debug message that will be included in the result. + constexpr static RunResult not_supported(std::string_view msg) + { + return RunResult{.error = std::string(msg)}; + } + + /// @brief Utility function for constructing a RunResult from an average runtime, + /// indicating a successful operation. + /// + /// @param runtime The runtime of the kernel in milliseconds. + constexpr static RunResult from_runtime(const float runtime) + { + return RunResult{.runtime = runtime}; + } + + /// @brief Returns whether this algorithm executed successfully. + /// + /// In this case there should be no message in `error`. + bool is_supported() const { return !this->error.has_value(); } +}; + +inline std::ostream& operator<<(std::ostream& os, const RunResult& result) +{ + if(result.error.has_value()) + return os << "invalid run (" << result.error.value() << ")"; + else + return os << "successful run (" << result.runtime << " ms)"; +} + /// @brief Invoke a device operation created by CK Builder. /// /// This is the main function used to invoke a particular device operation @@ -318,13 +373,14 @@ ValidationReport validate(const Args& args, /// @param outputs The output tensor data. The contents will be overwritten by /// this function. /// @param s_conf Stream config used to launch kernel. -/// @return std::tuple - whether the problem is supported and -/// kernel execution time (0.0f if s_conf time_kernel is false). +/// @returns RunResult about how the operation completed (or not). /// /// @note This function is explicitly deleted to generate compile errors /// for missing implementations. +/// +/// @see RunResult template -std::tuple run(Operation& operation, +[[nodiscard]] RunResult run(Operation& operation, const Args& args, const Inputs& inputs, const Outputs& outputs, diff --git a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp index 81d5b7a6f5..076b5e9751 100644 --- a/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/testing_reflect.hpp @@ -5,6 +5,8 @@ #include +#include "ck_tile/builder/testing/testing.hpp" + /// testing.hpp requires developers of a type of SIGNATURE to implement /// quite a lot of functionality for each SIGNATURE. For example, next /// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index 9890563859..73a682f10c 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/ck/test_ckb_conv_fwd_3d_fp16.cpp conv/ck/test_ckb_conv_fwd_3d_fp32.cpp conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp - ) +) target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility) set(BWD_WEIGHT_TESTS diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp index 4ad97209e5..a3f4a988ef 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_bwd_weight_xdl_cshuffle_v3.cpp @@ -1,23 +1,30 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/host/device_prop.hpp" +#include "testing_utils.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; + using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1, .direction = ckb::ConvDirection::BACKWARD_WEIGHT, .data_type = ckb::DataType::BF16, .accumulation_data_type = ckb::DataType::FP32, - .input = {.config = {.layout = NGCW}}, + .input = {.config = {.layout = GNWC}}, .weight = {.config = {.layout = GKXC}}, - .output = {.config = {.layout = NGKW}}}; + .output = {.config = {.layout = GNWK}}}; constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{} @@ -30,14 +37,58 @@ constexpr auto ALGORITHM = using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; +using Reference = ckb::ConvBuilder::Instance; + TEST(BwdWeight_1DBf16_CShuffle_V3, Create) { const auto expected_transfer_parameters = to_string(ALGORITHM); cku::run_test({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3", expected_transfer_parameters, "Filter1x1Stride1Pad0", - "NGCW,GKXC,NGKW", + "GNWC,GKXC,GNWK", "PassThrough,PassThrough,PassThrough", "Intrawave", "v2"}); } + +TEST(BwdWeight_1DBf16_CShuffle_V3, Execution) +{ + if(!ck_tile::get_device_name().starts_with("gfx9")) + { + // Note: XDL kernel + GTEST_SKIP() << "unsupported architecture"; + } + + ckt::Args args = { + .lengths = + { + .batch_size = 16, + .groups = 1, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 64}, + .filter = {.width = 1}, + }, + .filter_strides = {.width = 1}, + .filter_dilation = {.width = 1}, + .input_left_pad = {.width = 0}, + .input_right_pad = {.width = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp index 3e5e39191e..51bc45c29b 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_2d_fp16.cpp @@ -4,8 +4,9 @@ #include "utils/ckb_conv_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd_ck.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, @@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create) "MNKPadding"}); } -TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) +TEST(Fwd2DFp16_CShufV3_GNHWC, Execution) { if(!ck_tile::get_device_name().starts_with("gfx9")) { + // Note: XDL kernel GTEST_SKIP() << "unsupported architecture"; } @@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 292d852b91..60dc45545f 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -1,35 +1,47 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT +#include "ck_tile/builder/testing/conv/bwd_weight.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" +#include "ck_tile/host/device_prop.hpp" #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" +#include "testing_utils.hpp" -namespace { +namespace ckb = ck_tile::builder; +namespace ckt = ck_tile::builder::test; +namespace cku = ck_tile::builder::test_utils; -using namespace ck_tile::builder::test_utils; +using enum ck_tile::builder::TensorLayout; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; -TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC) +constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2, + .direction = ckb::ConvDirection::BACKWARD_WEIGHT, + .data_type = ckb::DataType::FP16, + .accumulation_data_type = ckb::DataType::FP32, + .input = {.config = {.layout = NHWGC}}, + .weight = {.config = {.layout = GKYXC}}, + .output = {.config = {.layout = NHWGK}}}; + +constexpr auto ALGORITHM = + cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{} + .with_tile_specializations(ckb::TileConvSpecialization::DEFAULT) + .with_tile_thread_block(cku::TileThreadBlock_64x64x64) + .with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave) + .with_tile_transfer(cku::TileTransfer_4x4x4) + .with_tile_optimizations(ckt::TileOptimizations{ + .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); + +using Builder = ckb::ConvBuilder; +using Instance = Builder::Instance; + +using Reference = ckb::ConvBuilder::Instance; + +TEST(BwdWeight_2D_FP16_NHWGC, Create) { - constexpr ConvSignature BwdWeightConvSignature{ - .spatial_dim = 2, - .direction = ConvDirection::BACKWARD_WEIGHT, - .data_type = DataType::FP16, - .accumulation_data_type = DataType::FP32, - .input = {.config = {.layout = TensorLayout::NHWGC}}, - .weight = {.config = {.layout = TensorLayout::GKYXC}}, - .output = {.config = {.layout = TensorLayout::NHWGK}}}; - - constexpr auto BwdWeightConvAlgorithm = - ConvAlgorithm_Tile_GroupedConvolutionKernel{} - .with_tile_specializations(TileConvSpecialization::DEFAULT) - .with_tile_thread_block(TileThreadBlock_64x64x64) - .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave) - .with_tile_transfer(TileTransfer_4x4x4) - .with_tile_optimizations(TileOptimizations{ - .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false}); - - using Builder = ConvBuilder; - run_ck_tile_test({ + cku::run_ck_tile_test({ "grouped_convolution_backward_weight", "fp16", "NHWGC_GKYXC_NHWGK", @@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_ }); } -} // namespace +TEST(BwdWeight_2D_FP16_NHWGC, Execution) +{ + ckt::Args args = { + .lengths = + { + .batch_size = 2, + .groups = 4, + .input_channels = 32, + .output_channels = 48, + .image = {.width = 32, .height = 56}, + .filter = {.width = 3, .height = 3}, + }, + .filter_strides = {.width = 1, .height = 1}, + .filter_dilation = {.width = 1, .height = 1}, + .input_left_pad = {.width = 0, .height = 0}, + .input_right_pad = {.width = 0, .height = 0}, + .a_elementwise_op = {}, + .b_elementwise_op = {}, + .cde_elementwise_op = {}, + }; + + auto inputs = ckt::alloc_inputs(args); + auto outputs = ckt::alloc_outputs(args); + auto reference = ckt::alloc_outputs(args); + + ckt::init_inputs(args, inputs.get()); + + auto conv = Instance{}; + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); + + auto ref_conv = Reference{}; + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); + + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); +} diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp index 128744dcc6..650c217b71 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_e2e.cpp @@ -4,8 +4,8 @@ #include "utils/ckb_conv_tile_test_configs.hpp" #include "utils/ckb_conv_test_utils.hpp" #include "utils/conv_algorithm_type_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/host/device_prop.hpp" #include "testing_utils.hpp" @@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; namespace cku = ck_tile::builder::test_utils; +using ck_tile::test::MatchesReference; +using ck_tile::test::SuccessfulRun; + constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 2, .direction = ckb::ConvDirection::FORWARD, @@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd) ckt::init_inputs(args, inputs.get()); auto conv = Instance{}; - ckt::run(conv, args, inputs.get(), outputs.get()); + EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun()); auto ref_conv = Reference{}; - ckt::run(ref_conv, args, inputs.get(), reference.get()); + EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun()); - EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get())); + EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get())); } diff --git a/experimental/builder/test/test_testing_utils.cpp b/experimental/builder/test/test_testing_utils.cpp index 43bbbd69eb..100122eef3 100644 --- a/experimental/builder/test/test_testing_utils.cpp +++ b/experimental/builder/test/test_testing_utils.cpp @@ -5,11 +5,14 @@ #include "testing_utils.hpp" +namespace ckt = ck_tile::builder::test; + using ck_tile::test::HipError; using ck_tile::test::HipSuccess; using ck_tile::test::InstanceMatcher; using ck_tile::test::InstanceSet; using ck_tile::test::StringEqWithDiff; +using ck_tile::test::SuccessfulRun; TEST(InstanceSet, FromFactory) { @@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic) EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue))); EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue))); } + +TEST(RunResultMatcher, Basic) +{ + EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun()); + EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun())); +} + +TEST(RunResultMatcher, ExplainMatchResult) +{ + testing::StringMatchResultListener listener; + EXPECT_TRUE(!ExplainMatchResult( + SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener)); + EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error")); +} diff --git a/experimental/builder/test/testing_utils.cpp b/experimental/builder/test/testing_utils.cpp index b60c35333e..e9677e5940 100644 --- a/experimental/builder/test/testing_utils.cpp +++ b/experimental/builder/test/testing_utils.cpp @@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const return ::testing::MakeMatcher(new HipStatusMatcher(error)); } +bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const +{ + if(actual.error.has_value() && listener) + *listener << "run failed: " << actual.error.value(); + + return actual.is_supported(); +} + +void RunResultMatcher::DescribeTo(std::ostream* os) const { *os << "successful run"; } + +void RunResultMatcher::DescribeNegationTo(std::ostream* os) const { *os << "unsuccessful run"; } + +::testing::Matcher SuccessfulRun() +{ + return ::testing::MakeMatcher(new RunResultMatcher()); +} + } // namespace ck_tile::test diff --git a/experimental/builder/test/testing_utils.hpp b/experimental/builder/test/testing_utils.hpp index b84d53b6df..55de133a2a 100644 --- a/experimental/builder/test/testing_utils.hpp +++ b/experimental/builder/test/testing_utils.hpp @@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface /// @param error The error to expect. ::testing::Matcher HipError(hipError_t error); +/// @brief RunResult matcher +/// +/// `ckt::run` returns a RunResult which indicates whether there was any +/// problem while running the algorithm. This matcher is used to match those +/// values. +struct RunResultMatcher : public ::testing::MatcherInterface +{ + bool MatchAndExplain(builder::test::RunResult actual, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override; +}; + +/// @brief Construct a Google Test matcher that checks that a ckt::run result +/// was successful. +::testing::Matcher SuccessfulRun(); + template struct ReferenceOutputMatcher : public ::testing::MatcherInterface> @@ -180,6 +197,21 @@ struct ReferenceOutputMatcher if(listener->IsInterested() && !errors.empty()) { *listener << errors.size() << " tensors failed to validate"; + + for(const auto& e : errors) + { + *listener << "\n - " << e.tensor_name << ": "; + + if(e.is_all_zero()) + *listener << "all elements in actual and expected tensors are zero"; + else + { + // Round to 2 digits + const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f; + *listener << e.wrong_elements << "/" << e.total_elements + << " incorrect elements (~" << percentage << "%)"; + } + } } return errors.empty(); diff --git a/experimental/builder/test/unit_conv_fwd_testing.cpp b/experimental/builder/test/unit_conv_fwd_testing.cpp index be95a29a2d..9fc07568b4 100644 --- a/experimental/builder/test/unit_conv_fwd_testing.cpp +++ b/experimental/builder/test/unit_conv_fwd_testing.cpp @@ -3,7 +3,7 @@ #include "impl/conv_signature_types.hpp" #include "testing_utils.hpp" -#include "ck_tile/builder/testing/conv_fwd.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" #include "ck_tile/builder/testing/tensor_foreach.hpp" #include #include diff --git a/experimental/builder/test/unit_validation.cpp b/experimental/builder/test/unit_validation.cpp index a83d034ac2..0dad8593fb 100644 --- a/experimental/builder/test/unit_validation.cpp +++ b/experimental/builder/test/unit_validation.cpp @@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect) testing::StringMatchResultListener listener; EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener)); - EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate")); + EXPECT_THAT(listener.str(), + StringEqWithDiff( // + "1 tensors failed to validate\n" + " - a: 625/625 incorrect elements (~100%)")); } diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc index 4b4c144428..ae451caec0 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_includes.inc @@ -1,5 +1,6 @@ #include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ckb = ck_tile::builder; namespace ckt = ck_tile::builder::test; diff --git a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc index 6b8024fa93..016ef3e653 100644 --- a/experimental/grouped_convolution_tile_instances/instances/instance_run.inc +++ b/experimental/grouped_convolution_tile_instances/instances/instance_run.inc @@ -2,8 +2,6 @@ using Builder = ckb::ConvBuilder; using Instance = Builder::Instance; -auto conv = Instance{}; -bool is_supported; -float avg_time; -std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf); -return std::make_tuple(is_supported, avg_time, conv.GetInstanceString()); +auto conv = Instance{}; +ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf); +return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString()); diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp index e58c884729..9f7227a699 100644 --- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -9,8 +9,9 @@ #include "grouped_convolution_signatures.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" -#include "ck_tile/builder/testing/conv_fwd_reference.hpp" +#include "ck_tile/builder/testing/conv/fwd.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" +#include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/builder/conv_builder.hpp" namespace ck_tile::builder::profiling { @@ -113,8 +114,8 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args, auto reference = ckt::alloc_outputs(args); using ReferenceInstance = typename ckb::ConvBuilder::Instance; - auto ref_conv = ReferenceInstance{}; - ckt::run(ref_conv, args, inputs, reference.get()); + auto ref_conv = ReferenceInstance{}; + [[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get()); [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); diff --git a/profiler/include/profiler/grouped_convolution_signatures.hpp b/profiler/include/profiler/grouped_convolution_signatures.hpp index 5103b0f235..0f87e283bb 100644 --- a/profiler/include/profiler/grouped_convolution_signatures.hpp +++ b/profiler/include/profiler/grouped_convolution_signatures.hpp @@ -6,7 +6,7 @@ #include #include "../../experimental/builder/test/impl/conv_signature_types.hpp" -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" namespace ck_tile::builder::profiling { diff --git a/profiler/src/profile_grouped_conv_fwd_tile.cpp b/profiler/src/profile_grouped_conv_fwd_tile.cpp index 8023dcf2f6..1a1e8b769a 100644 --- a/profiler/src/profile_grouped_conv_fwd_tile.cpp +++ b/profiler/src/profile_grouped_conv_fwd_tile.cpp @@ -6,7 +6,7 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp index c04a15ec98..068811cf00 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -7,7 +7,7 @@ #include #include -#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp" +#include "ck_tile/builder/testing/conv/ck_tile.hpp" #include "ck_tile/host/device_prop.hpp" #include "profiler/grouped_convolution_forward_tile_algs.hpp" From c190d8d61f2ea44a0d04b8c6706434098ca0c691 Mon Sep 17 00:00:00 2001 From: Johannes Graner Date: Tue, 27 Jan 2026 09:49:42 +0100 Subject: [PATCH 21/28] [CK tests] Extend conv GPU reference (#3539) * test_convnd_fwd * test_convnd_bwd_data * test_conv_bwd_data_scale * test_grouped_convnd_fwd_clamp * test_grouped_convnd_fwd_scale * multiple A/B tensors and D tensor for fwd GPU ref * test_grouped_convnd_fwd_scaleadd_ab * test_grouped_convnd_fwd_bias_clamp * test_grouped_convnd_fwd_bilinear * test_grouped_convnd_fwd_gk_bias_clamp * Extend GPU reference to enable batchnorm epilogue * test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp * test_grouped_conv_bwd_data_bilinear * test_grouped_convnd_bwd_weight_bilinear * Add missing template instantiation * Perform operations in float in reference * Slightly increase tolerance for batchnorm profiler * Revert "Slightly increase tolerance for batchnorm profiler" This reverts commit a3b247522902c712930369f466c376a6430f4f67. * Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp" This reverts commit 6da4576060215e1d3e0e79ca355c340d3546363c. * Revert "Extend GPU reference to enable batchnorm epilogue" This reverts commit e2f75fa10e80740eddb7a46f0a51aaac74b8f1a5. * Clarify variable names * Refactor elementwise ops into helper functions * Make helpers C++17-compatible --- .../element/unary_element_wise_operation.hpp | 23 + .../gpu/naive_conv_bwd_data_gpu.hpp | 465 ++++++++++++----- .../gpu/naive_conv_bwd_weight_gpu.hpp | 475 ++++++++++++++---- .../gpu/naive_conv_fwd_gpu.hpp | 468 +++++++++++++---- .../gpu/naive_conv_utils.hpp | 117 ++++- .../profiler/profile_conv_bwd_data_impl.hpp | 56 ++- .../profiler/profile_conv_fwd_impl.hpp | 45 +- ...ofile_grouped_conv_fwd_bias_clamp_impl.hpp | 73 ++- ...profile_grouped_conv_fwd_bilinear_impl.hpp | 59 ++- ...ile_grouped_conv_fwd_outelementop_impl.hpp | 77 ++- test/convnd_bwd_data/convnd_bwd_data_xdl.cpp | 2 +- test/convnd_fwd/convnd_fwd_xdl.cpp | 2 +- test/gpu_reference/CMakeLists.txt | 3 + test/gpu_reference/gpu_reference_utils.hpp | 225 +++++++++ .../test_gpu_reference_conv_fwd_multi_abd.cpp | 319 ++++++++++++ .../test_grouped_conv_bwd_data_bilinear.cpp | 81 +-- .../test_grouped_conv_bwd_data_scale.cpp | 51 +- ...est_grouped_convnd_bwd_weight_bilinear.cpp | 83 +-- .../test_grouped_convnd_fwd_bilinear.cpp | 4 +- .../test_grouped_convnd_fwd_scaleadd_ab.cpp | 52 +- .../test_grouped_convnd_fwd_bias_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_gk_bias_clamp.cpp | 2 +- .../test_grouped_convnd_fwd_scale.cpp | 4 +- 24 files changed, 2217 insertions(+), 473 deletions(-) create mode 100644 test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 6cd7b3d9f6..31047c03b2 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -1631,6 +1631,13 @@ struct ConvInvscale e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float / scale_in_ / scale_wei_ / scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1656,6 +1663,13 @@ struct ConvScale e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + e = type_convert(c_float * scale_in_ * scale_wei_ * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; @@ -1683,6 +1697,15 @@ struct ConvScaleRelu e = type_convert(x * scale_out_); }; + template <> + __host__ __device__ void operator()(f8_t& e, const f8_t& c) const + { + const float c_float = type_convert(c); + float x; + Relu{}.template operator()(x, c_float * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + float scale_in_; float scale_wei_; float scale_out_; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp index aecf519c10..5210265cef 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp @@ -10,49 +10,55 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward data convolution kernel working with packed (contiguous) tensors -// Computes gradients w.r.t. input from output gradients and weights -// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], -// output[G][N][K][spatial] +// Optimized backward data convolution kernel working with packed (contiguous) tensors with +// multi-ABD support Computes gradients w.r.t. input from output gradients and weights Assumes +// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial] template -__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - const OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in, + const WeiDataType* const* __restrict__ p_weis, + const OutDataType* const* __restrict__ p_outs, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t x = 0; x < X; ++x) { @@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t wo = w_tmp / stride_x; if(wo >= 0 && wo < Wo) { - const OutDataType* out_gnk = out_gn; - const WeiDataType* wei_gkc = wei_g + c * wei_stride_c; + // Pointers at current filter position + const OutDataType* output_grad_g_n_k = output_grad_g_n; + const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c; for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnk[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkc[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_g_n_k, + p_outs + 1, + g * out_stride_g + n * out_stride_n, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_g_k_c, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } } } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val; } } @@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t y = 0; y < Y; ++y) { @@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkh = out_gn + ho * out_stride_h; - const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h; + const WeiDataType* weight_at_c_y = + weight_g + c * wei_stride_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, out_gnkh[k * out_stride_k + wo]); - wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation with extra + // A tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_outs + 1, + g * out_stride_g + n * out_stride_n + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op(in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + hi * p_d_strides[0][3] + + wi * p_d_strides[0][4]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] = in_val; } @@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n; - const WeiDataType* wei_g = p_wei + g * wei_stride_g; + float acc = 0.0f; + // Base pointers for current group and batch + const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n; + const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g; for(index_t z = 0; z < Z; ++z) { @@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t do_idx = d_tmp / stride_z; if(do_idx >= 0 && do_idx < Do) { - const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d; - const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z; + // Pointers at current spatial depth + const OutDataType* output_grad_at_d = + output_grad_g_n + do_idx * out_stride_d; + const WeiDataType* weight_at_c_z = + weight_g + c * wei_stride_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { @@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, long_index_t ho = h_tmp / stride_y; if(ho >= 0 && ho < Ho) { - const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; + const WeiDataType* weight_at_c_z_y = + weight_at_c_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { @@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, { for(index_t k = 0; k < K; ++k) { - out_op(out_val, - out_gnkdh[k * out_stride_k + wo]); - wei_op(wei_val, - wei_gkczy[k * wei_stride_k + x]); + // Handle output gradient element-wise operation + // with extra A tensors + detail::apply_multi_tensor_elementwise_op< + NumAExtra>(out_val, + out_op, + output_grad_at_d_h, + p_outs + 1, + g * out_stride_g + + n * out_stride_n + + do_idx * out_stride_d + + ho * out_stride_h, + k * out_stride_k + wo); + + // Handle weight element-wise operation with + // extra B tensors + detail::apply_multi_tensor_elementwise_op< + NumBExtra>( + wei_val, + wei_op, + weight_at_c_z_y, + p_weis + 1, + g * wei_stride_g + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + k * wei_stride_k + x); + acc += type_convert(out_val) * type_convert(wei_val); } @@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in, } } - InDataType result = type_convert(acc); - in_op(in_val, result); + detail::apply_d_tensor_elementwise_op( + in_val, + in_op, + acc, + p_ds, + p_d_strides, + g, + n, + c, + di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]); + p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d + hi * in_stride_h + wi] = in_val; } } } -// GPU reference backward data convolution - takes ConvParam directly -template -void naive_conv_bwd_data(TIn* p_in, - const TWei* p_wei, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility +void naive_conv_bwd_data_multi_abd( + TIn* p_in, + const std::array& p_weis, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in, // Allocate packed buffers SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); - TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + + std::vector out_packed_bufs; + out_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + out_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_outs_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_outs_packed[i] = static_cast(out_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in, // Pack output and weight tensors to contiguous layout (inputs to bwd data) constexpr int block_size = 256; - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TOut** d_outs_ptrs = static_cast(outs_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs, + p_outs_packed.data(), + (NumAElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in, if(ndim == 1) { - naive_conv_bwd_data_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_data_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in, } else // 3D { - naive_conv_bwd_data_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> + naive_conv_bwd_data_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> <<>>(p_in_packed, - p_wei_packed, - p_out_packed, + d_weis_ptrs, + d_outs_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_data - now a zero-overhead wrapper +template +inline void naive_conv_bwd_data(TIn* p_in, + const TWei* p_wei, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_weis = {p_wei}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in, + p_weis, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp index f46b072baa..8cee2e2b77 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp @@ -10,49 +10,58 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized backward weight convolution kernel working with packed (contiguous) tensors +// Optimized backward weight convolution kernel working with packed (contiguous) tensors with +// multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial], // weight_grad[G][K][C][filter] // Computes gradient with respect to weights template -__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in, - WeiDataType* __restrict__ p_wei_grad, - const OutDataType* __restrict__ p_out_grad, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void +naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins, + WeiDataType* __restrict__ p_wei_grad, + const OutDataType* const* __restrict__ p_out_grads, + const DDataType* const* __restrict__ p_ds, + const index_t* const* __restrict__ p_d_strides, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gn[wi]); - out_op(out_val, out_gn_k[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_n_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle output gradient element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_n_k, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val; } } @@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h; + // Pointers at current spatial height + const InDataType* input_at_h = input_at_n_c + hi * in_stride_h; + const OutDataType* output_grad_at_h = + output_grad_at_n_k + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - out_op(out_val, out_gn_kh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra B + // tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } } @@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op(wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + y * p_d_strides[0][3] + + x * p_d_strides[0][4]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y + x] = wei_val; } @@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in const index_t k = remaining % K; const index_t g = remaining / K; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g; - const OutDataType* out_grad = p_out_grad + g * out_stride_g; + float acc = 0.0f; + // Base pointers for current group + const InDataType* input_g = p_ins[0] + g * in_stride_g; + const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g; // Loop over batch and output positions for(index_t n = 0; n < N; ++n) { - const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c; - const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k; + // Pointers at current batch and input channel + const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c; + const OutDataType* output_grad_at_n_k = + output_grad_g + n * out_stride_n + k * out_stride_k; for(index_t do_idx = 0; do_idx < Do; ++do_idx) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_n_c + di * in_stride_d; + const OutDataType* output_grad_at_d = + output_grad_at_n_k + do_idx * out_stride_d; for(index_t ho = 0; ho < Ho; ++ho) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const OutDataType* output_grad_at_d_h = + output_grad_at_d + ho * out_stride_h; for(index_t wo = 0; wo < Wo; ++wo) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - out_op(out_val, out_gn_kdh[wo]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle output gradient element-wise operation with extra + // B tensors + detail::apply_multi_tensor_elementwise_op( + out_val, + out_op, + output_grad_at_d_h, + p_out_grads + 1, + g * out_stride_g + n * out_stride_n + k * out_stride_k + + do_idx * out_stride_d + ho * out_stride_h, + wo); + acc += type_convert(out_val) * type_convert(in_val); } @@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in } } - WeiDataType result = type_convert(acc); - wei_op(wei_val, result); + detail::apply_d_tensor_elementwise_op( + wei_val, + wei_op, + acc, + p_ds, + p_d_strides, + g, + k, + c, + z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]); + p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z + y * wei_stride_y + x] = wei_val; } } } -// GPU reference backward weight convolution - takes ConvParam directly -template -void naive_conv_bwd_weight(const TIn* p_in, - TWei* p_wei_grad, - const TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility +void naive_conv_bwd_weight_multi_abd( + const std::array& p_ins, + TWei* p_wei_grad, + const std::array& p_outs, + const std::array& p_ds, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in, out_total *= l; // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); - SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut)); + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei)); + + std::vector out_grad_packed_bufs; + out_grad_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut)); + } + + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); TWei* p_wei_grad_packed = static_cast(wei_grad_packed_buf.GetDeviceBuffer()); - TOut* p_out_grad_packed = static_cast(out_grad_packed_buf.GetDeviceBuffer()); + + std::array p_out_grads_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_out_grads_packed[i] = static_cast(out_grad_packed_bufs[i].GetDeviceBuffer()); + } // Compute strides and allocate device arrays for pack/unpack std::vector in_strides = compute_conv_tensor_strides(in_lengths, ndim); @@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in, // Pack input and output_grad tensors to contiguous layout (inputs to bwd weight) constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total); + + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_outs[i], + p_out_grads_packed[i], + d_out_lengths, + d_out_strides, + dim_count, + out_total); + } + + // Prepare D tensor stride arrays on device + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TOut** d_out_grads_ptrs = static_cast(out_grads_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs, + p_out_grads_packed.data(), + (NumBElementwise + 1) * sizeof(TOut*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in, if(ndim == 1) { - naive_conv_bwd_weight_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else if(ndim == 2) { - naive_conv_bwd_weight_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in, } else // 3D { - naive_conv_bwd_weight_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, + naive_conv_bwd_weight_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, p_wei_grad_packed, - p_out_grad_packed, + d_out_grads_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, G, N, K, @@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_bwd_weight - now a zero-overhead wrapper +template +inline void +naive_conv_bwd_weight(const TIn* p_in, + TWei* p_wei_grad, + const TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_outs = {p_out}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_wei_grad, + p_outs, + p_ds, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp index 131b632a25..7bf9b49998 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp @@ -10,48 +10,56 @@ #include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include namespace ck { namespace ref { -// Optimized convolution kernel working with packed (contiguous) tensors +// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support // Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], // output[G][N][K][spatial] template -__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, - const WeiDataType* __restrict__ p_wei, - OutDataType* __restrict__ p_out, - index_t G, - index_t N, - index_t K, - index_t C, - index_t Di, - index_t Hi, - index_t Wi, - index_t Z, - index_t Y, - index_t X, - index_t Do, - index_t Ho, - index_t Wo, - index_t stride_z, - index_t stride_y, - index_t stride_x, - index_t dilation_z, - index_t dilation_y, - index_t dilation_x, - index_t pad_z, - index_t pad_y, - index_t pad_x, - InElementOp in_op, - WeiElementOp wei_op, - OutElementOp out_op) +__global__ void naive_conv_fwd_packed_multi_abd( + const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra) + const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra) + const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers + const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays + OutDataType* __restrict__ p_out, + index_t G, + index_t N, + index_t K, + index_t C, + index_t Di, + index_t Hi, + index_t Wi, + index_t Z, + index_t Y, + index_t X, + index_t Do, + index_t Ho, + index_t Wo, + index_t stride_z, + index_t stride_y, + index_t stride_x, + index_t dilation_z, + index_t dilation_y, + index_t dilation_x, + index_t pad_z, + index_t pad_y, + index_t pad_x, + InElementOp in_op, + WeiElementOp wei_op, + OutElementOp out_op) { const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x; const long_index_t num_threads = blockDim.x * gridDim.x; @@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gc = in_g + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gc[wi]); - wei_op(wei_val, wei_gkc[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_c, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_c, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val; } } @@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gnch = in_gnc + hi * in_stride_h; - const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y; + // Pointers at current spatial height and filter Y position + const InDataType* input_at_h = input_at_c + hi * in_stride_h; + const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gnch[wi]); - wei_op(wei_val, wei_gkcy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } } @@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op(out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + ho * p_d_strides[0][3] + + wo * p_d_strides[0][4]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] = out_val; } @@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, const index_t n = remaining % N; const index_t g = remaining / N; - float acc = 0.0f; - const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n; - const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k; + float acc = 0.0f; + // Base pointers for current group, batch, and output channel + const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n; + const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k; for(index_t c = 0; c < C; ++c) { - const InDataType* in_gnc = in_gn + c * in_stride_c; - const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c; + // Pointers at current input channel + const InDataType* input_at_c = input_g_n + c * in_stride_c; + const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c; for(index_t z = 0; z < Z; ++z) { long_index_t di = do_idx * stride_z + z * dilation_z - pad_z; if(di >= 0 && di < Di) { - const InDataType* in_gncd = in_gnc + di * in_stride_d; - const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z; + // Pointers at current spatial depth + const InDataType* input_at_d = input_at_c + di * in_stride_d; + const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z; for(index_t y = 0; y < Y; ++y) { long_index_t hi = ho * stride_y + y * dilation_y - pad_y; if(hi >= 0 && hi < Hi) { - const InDataType* in_gncdh = in_gncd + hi * in_stride_h; - const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y; + // Pointers at current spatial depth and height + const InDataType* input_at_d_h = input_at_d + hi * in_stride_h; + const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y; for(index_t x = 0; x < X; ++x) { long_index_t wi = wo * stride_x + x * dilation_x - pad_x; if(wi >= 0 && wi < Wi) { - in_op(in_val, in_gncdh[wi]); - wei_op(wei_val, wei_gkczy[x]); + // Handle input element-wise operation with extra A tensors + detail::apply_multi_tensor_elementwise_op( + in_val, + in_op, + input_at_d_h, + p_ins + 1, + g * in_stride_g + n * in_stride_n + c * in_stride_c + + di * in_stride_d + hi * in_stride_h, + wi); + + // Handle weight element-wise operation with extra B tensors + detail::apply_multi_tensor_elementwise_op( + wei_val, + wei_op, + weight_at_z_y, + p_weis + 1, + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + + z * wei_stride_z + y * wei_stride_y, + x); + acc += type_convert(in_val) * type_convert(wei_val); } @@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in, } } - OutDataType result = type_convert(acc); - out_op(out_val, result); + detail::apply_d_tensor_elementwise_op( + out_val, + out_op, + acc, + p_ds, + p_d_strides, + g, + n, + k, + do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]); + p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d + ho * out_stride_h + wo] = out_val; } } } -// GPU reference convolution - takes ConvParam directly -template -void naive_conv_fwd(const TIn* p_in, - const TWei* p_wei, - TOut* p_out, - const ck::utils::conv::ConvParam& conv_param, - InElementwiseOperation in_element_op = InElementwiseOperation{}, - WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, - OutElementwiseOperation out_element_op = OutElementwiseOperation{}, - hipStream_t stream = nullptr) + typename OutElementwiseOperation, + typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility +void naive_conv_fwd_multi_abd( + const std::array& p_ins, + const std::array& p_weis, + const std::array& p_ds, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + [[maybe_unused]] const std::array, NumDElementwise>& d_lengths, + const std::array, NumDElementwise>& d_strides, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) { const auto ndim = conv_param.num_dim_spatial_; @@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in, for(auto l : out_lengths) out_total *= l; - // Allocate packed buffers - SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn)); - SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei)); + // Allocate packed buffers for all A and B tensors + // Use separate allocations to avoid copy assignment issues with RAII wrapper + std::vector in_packed_bufs; + in_packed_bufs.reserve(NumAElementwise + 1); + for(index_t i = 0; i <= NumAElementwise; ++i) + { + in_packed_bufs.emplace_back(in_total * sizeof(TIn)); + } + + std::vector wei_packed_bufs; + wei_packed_bufs.reserve(NumBElementwise + 1); + for(index_t i = 0; i <= NumBElementwise; ++i) + { + wei_packed_bufs.emplace_back(wei_total * sizeof(TWei)); + } + SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut)); - TIn* p_in_packed = static_cast(in_packed_buf.GetDeviceBuffer()); - TWei* p_wei_packed = static_cast(wei_packed_buf.GetDeviceBuffer()); + // Get packed buffer pointers + std::array p_ins_packed; + for(index_t i = 0; i <= NumAElementwise; ++i) + { + p_ins_packed[i] = static_cast(in_packed_bufs[i].GetDeviceBuffer()); + } + + std::array p_weis_packed; + for(index_t i = 0; i <= NumBElementwise; ++i) + { + p_weis_packed[i] = static_cast(wei_packed_bufs[i].GetDeviceBuffer()); + } + TOut* p_out_packed = static_cast(out_packed_buf.GetDeviceBuffer()); // Compute strides and allocate device arrays for pack/unpack @@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in, // Pack input and weight tensors to contiguous layout constexpr int block_size = 256; - strided_copy_kernel - <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total); - strided_copy_kernel - <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( - p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total); + + // Pack all A tensors + for(index_t i = 0; i <= NumAElementwise; ++i) + { + strided_copy_kernel + <<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total); + } + + // Pack all B tensors + for(index_t i = 0; i <= NumBElementwise; ++i) + { + strided_copy_kernel + <<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>( + p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total); + } + + // Prepare D tensor stride arrays on device + // NOTE: D tensors are NOT packed - they are used directly with their original strides + // to support broadcasting (e.g., BiasGK layout with zero strides) + std::vector d_stride_bufs; + std::array p_d_strides_dev = {}; + + if constexpr(NumDElementwise > 0) + { + d_stride_bufs.reserve(NumDElementwise); + + for(index_t i = 0; i < NumDElementwise; ++i) + { + // Allocate and copy strides to device + d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t)); + p_d_strides_dev[i] = static_cast(d_stride_bufs[i].GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i], + d_strides[i].data(), + d_strides[i].size() * sizeof(index_t), + hipMemcpyHostToDevice)); + } + } + + // Create device arrays of pointers + SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*)); + SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*)); + SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*)); + SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*)); + + TIn** d_ins_ptrs = static_cast(ins_ptrs_buf.GetDeviceBuffer()); + TWei** d_weis_ptrs = static_cast(weis_ptrs_buf.GetDeviceBuffer()); + TD** d_ds_ptrs = static_cast(ds_ptrs_buf.GetDeviceBuffer()); + index_t** d_d_strides_ptrs = static_cast(d_strides_ptrs_buf.GetDeviceBuffer()); + + HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs, + p_ins_packed.data(), + (NumAElementwise + 1) * sizeof(TIn*), + hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs, + p_weis_packed.data(), + (NumBElementwise + 1) * sizeof(TWei*), + hipMemcpyHostToDevice)); + + if constexpr(NumDElementwise > 0) + { + // D tensors use original pointers (not packed) to support broadcasting + std::array p_ds_dev; + for(index_t i = 0; i < NumDElementwise; ++i) + { + p_ds_dev[i] = p_ds[i]; + } + + HIP_CHECK_ERROR(hipMemcpy( + d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs, + p_d_strides_dev.data(), + NumDElementwise * sizeof(index_t*), + hipMemcpyHostToDevice)); + } // Build conv parameter vectors for kernel invocation std::vector conv_strides(ndim); @@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in, if(ndim == 1) { - naive_conv_fwd_packed<1, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<1, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in, } else if(ndim == 2) { - naive_conv_fwd_packed<2, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<2, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in, } else // 3D { - naive_conv_fwd_packed<3, - TIn, - TWei, - TOut, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation> - <<>>(p_in_packed, - p_wei_packed, + naive_conv_fwd_packed_multi_abd<3, + NumAElementwise, + NumBElementwise, + NumDElementwise, + TIn, + TWei, + TOut, + TD, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> + <<>>(d_ins_ptrs, + d_weis_ptrs, + d_ds_ptrs, + d_d_strides_ptrs, p_out_packed, G, N, @@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in, // Memory automatically freed by SimpleDeviceMem destructors } +// Original naive_conv_fwd - now a zero-overhead wrapper +template +inline void naive_conv_fwd(const TIn* p_in, + const TWei* p_wei, + TOut* p_out, + const ck::utils::conv::ConvParam& conv_param, + InElementwiseOperation in_element_op = InElementwiseOperation{}, + WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{}, + OutElementwiseOperation out_element_op = OutElementwiseOperation{}, + hipStream_t stream = nullptr) +{ + std::array p_ins = {p_in}; + std::array p_weis = {p_wei}; + std::array p_ds = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins, + p_weis, + p_ds, + p_out, + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op, + stream); +} + } // namespace ref } // namespace ck diff --git a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp index 0a7b58b310..50b65357a2 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp @@ -22,9 +22,39 @@ struct SimpleDeviceMem HIP_CHECK_ERROR(hipMalloc(static_cast(&p_mem_), mem_size)); } + // Delete copy operations (resource should not be copied) + SimpleDeviceMem(const SimpleDeviceMem&) = delete; + SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete; + + // Define move operations + SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_) + { + other.p_mem_ = nullptr; + } + + SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept + { + if(this != &other) + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + p_mem_ = other.p_mem_; + other.p_mem_ = nullptr; + } + return *this; + } + void* GetDeviceBuffer() { return p_mem_; } - ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + ~SimpleDeviceMem() + { + if(p_mem_) + { + (void)hipFree(p_mem_); + } + } void* p_mem_; }; @@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src, } } +namespace detail { + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result, + Op&& element_op, + const DataType* const* tensor_ptrs, + long_index_t element_offset, + std::index_sequence) +{ + element_op(result, tensor_ptrs[Is][element_offset]...); +} + +// Generic helper for A and B tensors (works in all directions) +template +__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result, + Op&& element_op, + const DataType* primary_ptr, + const DataType* const* extra_ptrs, + long_index_t extra_base_offset, + long_index_t element_offset) +{ + const DataType* tensor_ptrs[NumExtraTensors + 1]; + tensor_ptrs[0] = primary_ptr; + + static_for<1, NumExtraTensors + 1, 1>{}( + [&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; }); + + apply_multi_tensor_impl(result, + element_op, + tensor_ptrs, + element_offset, + std::make_index_sequence{}); +} + +// Helper for parameter pack expansion (D tensors) +template +__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out, + Op&& element_op, + float computed_value, + const float* d_values, + std::index_sequence) +{ + float temp_out; + element_op(temp_out, computed_value, d_values[Is]...); + result_out = type_convert(temp_out); +} + +// Specialized helper for D tensors with stride calculations and float conversion +template +__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out, + Op&& element_op, + float computed_value, + const DDataType* const* p_ds, + const index_t* const* p_d_strides, + index_t g, + index_t n, + index_t c_or_k, + long_index_t spatial_linear_index) +{ + if constexpr(NumDTensors == 0) + { + element_op(result_out, computed_value); + } + else + { + float d_values[NumDTensors]; + + // Compute all D tensor indices and convert to float + static_for<0, NumDTensors, 1>{}([&](auto i) { + const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] + + c_or_k * p_d_strides[i][2] + spatial_linear_index; + d_values[i] = type_convert(p_ds[i][d_idx]); + }); + + apply_d_tensor_impl(result_out, + element_op, + computed_value, + d_values, + std::make_index_sequence{}); + } +} + +} // namespace detail + } // namespace ref } // namespace ck diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index a0f9b9ac25..bf5ffcb5d2 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -17,6 +17,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" namespace ck { namespace profiler { @@ -129,7 +130,10 @@ bool profile_conv_bwd_data_impl(int do_verification, out_device_buf.ToDevice(output.mData.data()); wei_device_buf.ToDevice(weight.mData.data()); - if(do_verification) + // profile device Conv instances + bool pass = true; + + if(do_verification == 1) { auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData gpu_ref_input(in_g_n_c_wis_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * + input_device_result.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization + + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(gpu_ref_input.mData.data()); + } + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData gpu_ref_output(out_g_n_k_wos_desc); + if(do_verification == 2) + { + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(gpu_ref_output.mData.data()); + } using DeviceOp = ck::tensor_operation::device::DeviceConvFwd(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "gpu_ref_output : ", gpu_ref_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index 50cd58eec3..2a282edbc8 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -21,6 +21,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -156,8 +157,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, bias_device_buf.ToDevice(bias.mData.data()); // run reference op - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + if constexpr(BiasGK) + { + // For GK bias layout: G*K, zero strides for N and spatial dimensions + d_strides_vec[0] = K; + d_strides_vec[1] = 0; + d_strides_vec[2] = 1; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_strides_vec[3 + i] = 0; + } + } + else + { + // Full GNKHW layout - same as output + ck::ranges::copy(out_g_n_k_wos_desc.GetStrides(), d_strides_vec.begin()); + } + + std::array d_ptrs = { + reinterpret_cast(bias_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp index 3f4905c110..b439428cda 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bilinear_impl.hpp @@ -22,6 +22,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" namespace ck { namespace profiler { @@ -129,8 +130,9 @@ bool profile_grouped_conv_fwd_bilinear_impl( wei_device_buf.ToDevice(weight.mData.data()); d_device_buf.ToDevice(d_tensor.mData.data()); - if(do_verification) + if(do_verification == 1) { + // CPU reference auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd< NDimSpatial, InDataType, @@ -167,6 +169,61 @@ bool profile_grouped_conv_fwd_bilinear_impl( host_output(idx) = ck::type_convert(out_val); }); } + else if(do_verification == 2) + { + // GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + std::vector d_strides_vec(NDimSpatial + 3); + + d_lengths_vec[0] = conv_param.G_; + d_lengths_vec[1] = conv_param.N_; + d_lengths_vec[2] = conv_param.K_; + for(ck::index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(conv_param.output_spatial_lengths_[i]); + } + + // D tensor has same layout as output + ck::ranges::copy(d_host_tensor_descriptor.GetStrides(), d_strides_vec.begin()); + + std::array d_ptrs = { + reinterpret_cast(d_device_buf.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer())}; + + ck::ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + DDataType>( // Explicitly specify D tensor type + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + bilinear_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + out_device_buf.FromDevice(host_output.mData.data()); + } std::string best_op_name; float best_avg_time = 0; diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index acdc937a33..9444996c25 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -7,6 +7,7 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convinvscale.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_scale.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "profiler/common.hpp" @@ -150,7 +151,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, std::cout << "scale_out: " << scale_out << std::endl; // run reference op - if(do_verification) + if(do_verification == 1) { std::cout << "\nVerifying algorithm against reference convolution..." << std::endl; @@ -200,6 +201,57 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, } }); } + else if(do_verification == 2) + { + // GPU reference + // WORKAROUND: For int8_t with Scale, use CPU post-processing to match CPU reference + // Pure GPU approach fails int8 test (see 2026-01-07-int8-scale-debugging.md) + if constexpr(std::is_same_v && + std::is_same_v) + { + // Compute conv to CShuffleDataType (float), then post-process on CPU + DeviceMem gpu_ref_c_dev(sizeof(CShuffleDataType) * c.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_c_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + PassThrough{}); + + ck::hip_check_error(hipDeviceSynchronize()); + + Tensor gpu_c(out_g_n_k_wos_desc); + gpu_ref_c_dev.FromDevice(gpu_c.mData.data()); + + // Post-process on CPU to match CPU reference behavior + host_output.ForEach([&](auto&, auto idx) { + const auto conv_shuffle = ck::type_convert(gpu_c(idx)); + const auto conv_val = ck::type_convert(conv_shuffle); + out_element_op(host_output(idx), conv_val); + }); + } + else + { + // Normal path for non-int8 or non-Scale cases + DeviceMem gpu_ref_out_dev(sizeof(OutDataType) * + device_output.mDesc.GetElementSpaceSize()); + + ck::ref::naive_conv_fwd( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(gpu_ref_out_dev.GetDeviceBuffer()), + conv_param, + in_element_op, + wei_element_op, + out_element_op); + + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_out_dev.FromDevice(host_output.mData.data()); + } + } std::string best_op_name; float best_avg_time = 0; @@ -239,7 +291,7 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, best_gb_per_sec = gb_per_sec; } - if(do_verification) + if(do_verification == 1) { out_device_buf.FromDevice(device_output.mData.data()); @@ -259,6 +311,27 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, << std::endl; } } + else if(do_verification == 2) + { + out_device_buf.FromDevice(device_output.mData.data()); + + pass = + pass & ck::utils::check_err(device_output, + host_output, + "Error: Device and GPU ref results do not match!", + get_rtol(), + get_atol()); + + if(do_log) + { + LogRangeAsType(std::cout << "input : ", input.mData, ",") << std::endl; + LogRangeAsType(std::cout << "weight: ", weight.mData, ",") << std::endl; + LogRangeAsType(std::cout << "gpu_ref_output : ", host_output.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "device_output: ", device_output.mData, ",") + << std::endl; + } + } } else { diff --git a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp index 98f466a2b3..3e4eb07a64 100644 --- a/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data_xdl.cpp @@ -46,7 +46,7 @@ class TestConvndBwdData : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/convnd_fwd/convnd_fwd_xdl.cpp b/test/convnd_fwd/convnd_fwd_xdl.cpp index a2fdcaf870..0377b01bb2 100644 --- a/test/convnd_fwd/convnd_fwd_xdl.cpp +++ b/test/convnd_fwd/convnd_fwd_xdl.cpp @@ -47,7 +47,7 @@ class TestConvndFwd : public ::testing::Test ck::tensor_layout::convolution::NDHWK>>, DataType, DataType, - DataType>(true, // do_verification + DataType>(2, // do_verification: 2 = GPU reference 1, // init_method integer value false, // do_log false, // time_kernel diff --git a/test/gpu_reference/CMakeLists.txt b/test/gpu_reference/CMakeLists.txt index 443818feb3..d1c3908849 100644 --- a/test/gpu_reference/CMakeLists.txt +++ b/test/gpu_reference/CMakeLists.txt @@ -4,6 +4,9 @@ add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp) target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility) +add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp) +target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility) + add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp) target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility) diff --git a/test/gpu_reference/gpu_reference_utils.hpp b/test/gpu_reference/gpu_reference_utils.hpp index fc017c8734..88306d51a4 100644 --- a/test/gpu_reference/gpu_reference_utils.hpp +++ b/test/gpu_reference/gpu_reference_utils.hpp @@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType } } +// Forward convolution with D tensor support +template +bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& d_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& d_dev, + DeviceMem& output_dev, + OutElementOp out_element_op) +{ + using InElementOp = tensor_operation::element_wise::PassThrough; + using WeiElementOp = tensor_operation::element_wise::PassThrough; + + // Create D tensor lengths and strides for GPU reference + std::vector d_lengths_vec(NDimSpatial + 3); + d_lengths_vec[0] = params.G_; + d_lengths_vec[1] = params.N_; + d_lengths_vec[2] = params.K_; + for(index_t i = 0; i < NDimSpatial; ++i) + { + d_lengths_vec[3 + i] = static_cast(params.output_spatial_lengths_[i]); + } + + std::vector d_strides_vec = + ref::compute_conv_tensor_strides(d_lengths_vec, params.num_dim_spatial_); + + std::array d_ptrs = { + reinterpret_cast(d_dev.GetDeviceBuffer())}; + std::array, 1> d_lengths = {d_lengths_vec}; + std::array, 1> d_strides = {d_strides_vec}; + + // Call GPU reference with D tensor + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer())}; + + ref::naive_conv_fwd_multi_abd<0, + 0, + 1, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + OutDataType>( // Explicitly specify TD = OutDataType + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + InElementOp{}, + WeiElementOp{}, + out_element_op); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> d_tensors_ref = {d_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + InElementOp{}, + WeiElementOp{}, + out_element_op, + {}, // A tensors + {}, // B tensors + d_tensors_ref); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + +// Forward convolution with multiple A/B tensor support +template +bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params, + const Tensor& input_cpu, + const Tensor& weight_cpu, + const Tensor& a_extra_cpu, + const Tensor& b_extra_cpu, + DeviceMem& input_dev, + DeviceMem& weight_dev, + DeviceMem& a_extra_dev, + DeviceMem& b_extra_dev, + DeviceMem& output_dev, + InElementOp in_element_op, + WeiElementOp wei_element_op) +{ + using OutElementOp = tensor_operation::element_wise::PassThrough; + + // Call GPU reference with extra A and B tensors + std::array in_ptrs = { + reinterpret_cast(input_dev.GetDeviceBuffer()), + reinterpret_cast(a_extra_dev.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(weight_dev.GetDeviceBuffer()), + reinterpret_cast(b_extra_dev.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; + + ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(output_dev.GetDeviceBuffer()), + params, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + OutElementOp{}); + + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Run CPU reference + std::vector strides_long(params.conv_filter_strides_.begin(), + params.conv_filter_strides_.end()); + std::vector dilations_long(params.conv_filter_dilations_.begin(), + params.conv_filter_dilations_.end()); + std::vector pads_long(params.input_left_pads_.begin(), + params.input_left_pads_.end()); + + Tensor input_ref = input_cpu; + Tensor weight_ref = weight_cpu; + Tensor output_ref( + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params)); + + std::array, 1> a_tensors_ref = {a_extra_cpu}; + std::array, 1> b_tensors_ref = {b_extra_cpu}; + + auto ref_conv = tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_arg = ref_conv.MakeArgument(input_ref, + weight_ref, + output_ref, + strides_long, + dilations_long, + pads_long, + pads_long, + in_element_op, + wei_element_op, + OutElementOp{}, + a_tensors_ref, + b_tensors_ref, + {}); + ref_invoker.Run(ref_arg); + + // Copy result from device and compare + Tensor output_gpu(output_ref.mDesc); + output_dev.FromDevice(output_gpu.mData.data()); + HIP_CHECK_ERROR(hipDeviceSynchronize()); + + // Compare results + return ck::utils::check_err(output_gpu, output_ref); +} + } // namespace test } // namespace ck diff --git a/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp new file mode 100644 index 0000000000..ebe1e9695c --- /dev/null +++ b/test/gpu_reference/test_gpu_reference_conv_fwd_multi_abd.cpp @@ -0,0 +1,319 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "gpu_reference_utils.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" + +using namespace ck; +using ck::test::ConvKernelType; + +// ==================== D Tensor (Bias) Tests ==================== + +template +bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::AddClamp; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor bias(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(bias, bias_dev); + + // Test with AddClamp (bias operation with clamping) + AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias) +{ + auto params = test::conv_test_shapes::get_3d_small(); + bool result = test_conv_gpu_ref_with_bias<3, + float, + float, + float, + tensor_layout::convolution::GNCDHW, + tensor_layout::convolution::GKCZYX, + tensor_layout::convolution::GNKDHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bias<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g4(); + bool result = test_conv_gpu_ref_with_bias<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== D Tensor (Bilinear) Tests ==================== + +template +bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::Bilinear; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor d_tensor(out_g_n_k_wos_desc); // Same shape as output + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(d_tensor, d_dev); + + // Test with Bilinear: y = alpha * conv_result + beta * d_tensor + Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5 + + return test::test_conv_fwd_with_d_tensor_impl( + params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_bilinear<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_bilinear<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +// ==================== Multiple A/B (ScaleAdd) Tests ==================== + +template +bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params) +{ + using tensor_operation::element_wise::ScaleAdd; + + // Create tensor descriptors + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(params); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(params); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(params); + + // Create tensors + Tensor input(in_g_n_c_wis_desc); + Tensor weight(wei_g_k_c_xs_desc); + Tensor output(out_g_n_k_wos_desc); + Tensor a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input) + Tensor b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight) + + // Allocate device memory + DeviceMem input_dev(input.mData.size() * sizeof(InDataType)); + DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType)); + DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType)); + DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType)); + DeviceMem output_dev(output.mData.size() * sizeof(OutDataType)); + + // Initialize and copy tensors + test::initialize_and_copy_tensor(input, input_dev); + test::initialize_and_copy_tensor(weight, weight_dev); + test::initialize_and_copy_tensor(a_extra, a_extra_dev); + test::initialize_and_copy_tensor(b_extra, b_extra_dev); + + // Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1 + ScaleAdd in_element_op(2.0f); // scale factor for input + ScaleAdd wei_element_op(1.5f); // scale factor for weight + + return test::test_conv_fwd_with_multi_ab_impl(params, + input, + weight, + a_extra, + b_extra, + input_dev, + weight_dev, + a_extra_dev, + b_extra_dev, + output_dev, + in_element_op, + wei_element_op); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_small(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_medium(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + float, + float, + float, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} + +TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd) +{ + auto params = test::conv_test_shapes::get_2d_grouped_g2(); + bool result = test_conv_gpu_ref_with_scaleadd<2, + half_t, + half_t, + half_t, + tensor_layout::convolution::GNCHW, + tensor_layout::convolution::GKCYX, + tensor_layout::convolution::GNKHW>(params); + EXPECT_TRUE(result); +} diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp index b45f204b40..ea7289d6bf 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_bilinear.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -63,37 +63,62 @@ class TestGroupedConvndBwdData : public ::testing::Test Tensor& out, Tensor& d) { + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed( + conv_param); - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData(); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(in_g_n_c_wis_desc, d_lengths, d_strides); - auto ref_invoker = ref_conv.MakeInvoker(); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - Bilinear{alpha, beta}, - WeiElementOp{}, - OutElementOp{}, - {}, - {}, - d_tensors); + DeviceMem d_device_buf(sizeof(InDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); - ref_invoker.Run(ref_argument); + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_data_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + InDataType>( + static_cast(in_device_buf.GetDeviceBuffer()), + {static_cast(wei_device_buf.GetDeviceBuffer())}, + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{alpha, beta}, + WeiElementOp{}, + OutElementOp{}); + + in_device_buf.FromDevice(in_host.mData.data()); } bool PerformConvDataBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp index 84d013bca7..f1f985883c 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_conv_bwd_data_scale.cpp @@ -21,7 +21,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_data.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp" using ::ck::DeviceMem; using ::ck::HostTensorDescriptor; @@ -55,38 +55,24 @@ class TestGroupedConvndBwdData : public ::testing::Test void RunReference(ck::utils::conv::ConvParam& conv_param, Tensor& in_host, - Tensor& wei, - Tensor& out) + DeviceMem& wei_device_buf, + DeviceMem& out_device_buf) { - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdData /*Num D Elementwise - Tensors*/ - {}; + // GPU reference + DeviceMem gpu_ref_in_dev(sizeof(InDataType) * in_host.mDesc.GetElementSpaceSize()); + gpu_ref_in_dev.SetZero(); // bwd data needs zero initialization - auto ref_invoker = ref_conv.MakeInvoker(); + ck::ref::naive_conv_bwd_data( + static_cast(gpu_ref_in_dev.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + InElementOp{alpha}, + WeiElementOp{}, + OutElementOp{}); - auto ref_argument = ref_conv.MakeArgument(in_host, - wei, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{alpha}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); + ck::hip_check_error(hipDeviceSynchronize()); + gpu_ref_in_dev.FromDevice(in_host.mData.data()); } bool PerformConvDataScale(ck::utils::conv::ConvParam& conv_param, const ck::index_t split_k) @@ -121,10 +107,11 @@ class TestGroupedConvndBwdData : public ::testing::Test DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize()); DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); - in_device_buf.ToDevice(in_device.mData.data()); out_device_buf.ToDevice(out.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + RunReference(conv_param, in_host, wei_device_buf, out_device_buf); + std::array out_lengths{}; std::array out_strides{}; std::array wei_lengths{}; @@ -149,8 +136,6 @@ class TestGroupedConvndBwdData : public ::testing::Test copy(conv_param.input_left_pads_, input_left_pads); copy(conv_param.input_right_pads_, input_right_pads); - RunReference(conv_param, in_host, wei, out); - using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD& out, Tensor& d) { - std::array, NumDs> d_tensors = {d}; - auto ref_conv = - ck::tensor_operation::host::ReferenceConvBwdWeight{}; + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in, - wei_host, - out, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - InElementOp{}, - WeiElementOp{alpha, beta}, - OutElementOp{}, - {}, - {}, - d_tensors); + // Prepare D tensor with correct strides for GPU kernel + std::vector d_lengths; + std::vector d_strides; + auto copy_dims = [](const auto& desc, auto& lengths, auto& strides) { + const auto& l = desc.GetLengths(); + const auto& s = desc.GetStrides(); + lengths.assign(l.begin(), l.end()); + strides.assign(s.begin(), s.end()); + }; + copy_dims(wei_g_k_c_xs_desc, d_lengths, d_strides); - ref_invoker.Run(ref_argument); + std::array, NumDs> d_lengths_array = {d_lengths}; + std::array, NumDs> d_strides_array = {d_strides}; + + DeviceMem d_device_buf(sizeof(WeiDataType) * d.mDesc.GetElementSpaceSize()); + d_device_buf.ToDevice(d.mData.data()); + + std::array p_ds = { + static_cast(d_device_buf.GetDeviceBuffer())}; + + DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_host.mDesc.GetElementSpaceSize()); + DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpaceSize()); + + in_device_buf.ToDevice(in.mData.data()); + out_device_buf.ToDevice(out.mData.data()); + + ck::ref::naive_conv_bwd_weight_multi_abd<0, + 0, + NumDs, + InLayout, + WeiLayout, + OutLayout, + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp, + WeiDataType>( + {static_cast(in_device_buf.GetDeviceBuffer())}, + static_cast(wei_device_buf.GetDeviceBuffer()), + {static_cast(out_device_buf.GetDeviceBuffer())}, + p_ds, + conv_param, + d_lengths_array, + d_strides_array, + InElementOp{}, + WeiElementOp{alpha, beta}, + OutElementOp{}); + + wei_device_buf.FromDevice(wei_host.mData.data()); } bool PerformConvWeightBilinear(ck::utils::conv::ConvParam& conv_param, diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp index 1b37f5eb4e..645aab0151 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_bilinear.cpp @@ -66,10 +66,10 @@ class TestGroupedConvndFwdBilinear : public ::testing::Test OutDataType, AComputeType, BComputeType, - IndexType>(true, // do_verification + IndexType>(2, // do_verification 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param, bilinear_op); } diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp index 199a50f0fd..e78e61f707 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_scaleadd_ab.cpp @@ -24,6 +24,7 @@ #include "ck/library/utility/convolution_parameter.hpp" #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" +#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp" using I8 = int8_t; using F16 = ck::half_t; @@ -131,39 +132,34 @@ bool profile_grouped_conv_fwd_scaleadd_ab_impl(int do_verification, wei_device_buf.ToDevice(weight.mData.data()); wei_bias_device_buf.ToDevice(weight_bias.mData.data()); - // Run reference op + // Run GPU reference if(do_verification) { - const std::array, NumAs - 1> elementwise_a_tensors = {input_bias}; - const std::array, NumBs - 1> elementwise_b_tensors = {weight_bias}; - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + std::array in_ptrs = { + reinterpret_cast(in_device_buf.GetDeviceBuffer()), + reinterpret_cast(in_bias_device_buf.GetDeviceBuffer())}; + std::array wei_ptrs = { + reinterpret_cast(wei_device_buf.GetDeviceBuffer()), + reinterpret_cast(wei_bias_device_buf.GetDeviceBuffer())}; + std::array d_ptrs = {}; + std::array, 0> d_lengths = {}; + std::array, 0> d_strides = {}; - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weight, - host_output, - conv_param.conv_filter_strides_, - conv_param.conv_filter_dilations_, - conv_param.input_left_pads_, - conv_param.input_right_pads_, - in_element_op, - wei_element_op, - out_element_op, - elementwise_a_tensors, - elementwise_b_tensors); + ck::ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>( + in_ptrs, + wei_ptrs, + d_ptrs, + reinterpret_cast(out_device_buf.GetDeviceBuffer()), + conv_param, + d_lengths, + d_strides, + in_element_op, + wei_element_op, + out_element_op); - // init host output to zero - host_output.SetZero(); + HIP_CHECK_ERROR(hipDeviceSynchronize()); - ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(host_output.mData.data()); } std::string best_op_name; diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp index d1706d4cec..68a8b016e3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp.cpp @@ -49,7 +49,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, false /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp index fef485a950..2c04b52b4f 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_clamp.cpp @@ -50,7 +50,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, Clamp>( - true, // do_verification + 2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp index a78a17cbf4..78cfe126a3 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_gk_bias_clamp.cpp @@ -44,7 +44,7 @@ class TestGroupedConvndFwd : public ::testing::Test DataType, IndexType, true /*BiasGK*/>( - true, // do_verification + 2, // do_verification 1, // init_method: integer value false, // do_log false, // time_kernel diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp index b4179cae62..b2a9cff231 100644 --- a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_scale.cpp @@ -58,10 +58,10 @@ class TestGroupedConvndFwdScale : public ::testing::Test OutDataType, ck::tensor_operation::element_wise::Scale, InDataType, - InDataType>(true, // do_verification + InDataType>(2, // do_verification: 2 = GPU reference 1, // init_method: integer value false, // do_log - true, // time_kernel + false, // time_kernel param); } EXPECT_TRUE(pass); From 3d67e6c4927a9daea9076fab75b23fb44fdc22b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Tue, 27 Jan 2026 10:04:11 +0100 Subject: [PATCH 22/28] [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624) * [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err * Update test_grouped_convnd_fwd_tile.cpp * Update test_grouped_convnd_fwd_tile.cpp * Update conv_tuning_params.hpp * clang format fix * Update CMakeLists.txt --- .../factory/helpers/ck/conv_tuning_params.hpp | 3 + .../ck_tile/conv_tile_tuning_params.hpp | 8 +++ .../ck_tile/builder/testing/validation.hpp | 12 +++- .../builder/include/ck_tile/builder/types.hpp | 2 + .../configs/tests/ndhwgc_bf16.conf | 6 +- .../configs/tests/ndhwgc_fp16.conf | 6 +- .../configs/tests/ndhwgc_fp32.conf | 6 +- .../configs/tests/nhwgc_bf16.conf | 6 +- .../configs/tests/nhwgc_fp16.conf | 6 +- .../configs/tests/nhwgc_fp32.conf | 6 +- include/ck_tile/host/check_err.hpp | 2 +- .../grouped_convolution_forward_tile_algs.hpp | 55 +++++++++++++++++-- test/grouped_convnd_fwd/CMakeLists.txt | 13 ++--- .../test_grouped_convnd_fwd_tile.cpp | 29 +++++----- 14 files changed, 114 insertions(+), 46 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c..3b1ea65695 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm() case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break; case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break; case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM."; default: throw "Unknown PipelineVersion"; @@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion() case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K."; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM."; + case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE."; case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only; default: throw "Unknown GridwiseGemmPipelineVersion"; } @@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_pipeline::v3; case PipelineVersion::V4: return ck_pipeline::v4; case PipelineVersion::V5: return ck_pipeline::v5; + case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile."; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp index b7df0e4d0e..12482f3206 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp @@ -91,6 +91,13 @@ struct TilePipelineType using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; }; +template <> +struct TilePipelineType +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; +}; + template consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() { @@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion() case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3; case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4; case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5; + case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6; case PipelineVersion::WEIGHT_ONLY: throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version."; default: throw "Unknown block GEMM PipelineVersion"; diff --git a/experimental/builder/include/ck_tile/builder/testing/validation.hpp b/experimental/builder/include/ck_tile/builder/testing/validation.hpp index 158f271e21..8410a71b15 100644 --- a/experimental/builder/include/ck_tile/builder/testing/validation.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/validation.hpp @@ -51,6 +51,9 @@ struct ValidationReport /// The number of elements which were bitwise 0. uint64_t zero_elements; + // Max error. + double max_error; + /// @brief Check whether both the output and reference tensor were both all zeros. /// /// If both tensors are all zero, it indicates either an incorrect testing setup @@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name, // Initial pass: count errors // Allocate and reset counter - auto d_counters = alloc_buffer(sizeof(uint64_t) * 2); - check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2)); + auto d_counters = alloc_buffer(sizeof(uint64_t) * 3); + check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3)); auto d_error_count = &reinterpret_cast(d_counters.get())[0]; auto d_zero_count = &reinterpret_cast(d_counters.get())[1]; + auto d_max_error = &reinterpret_cast(d_counters.get())[2]; tensor_foreach(descriptor.get_lengths(), [=](auto index) { using CKType = typename factory::internal::DataTypeToCK
::type; @@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name, const auto r = static_cast(type_convert(b)); const auto err = std::abs(o - r); + atomicMax(d_max_error, err); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { // We expect the number of errors to be very low, so just use an atomic @@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name, check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); uint64_t zero_count = 0; check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost)); + double max_error = 0; + check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost)); // TODO: Gather detailed coordinates. @@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name, .wrong_elements = error_count, .total_elements = descriptor.get_element_size(), .zero_elements = zero_count, + .max_error = max_error, }); return reports_.back().is_ok(); diff --git a/experimental/builder/include/ck_tile/builder/types.hpp b/experimental/builder/include/ck_tile/builder/types.hpp index c4cca05e52..dad123bae5 100644 --- a/experimental/builder/include/ck_tile/builder/types.hpp +++ b/experimental/builder/include/ck_tile/builder/types.hpp @@ -157,6 +157,7 @@ enum class PipelineVersion V3, V4, V5, + V6, WEIGHT_ONLY }; @@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver) case V3: return "V3"; case V4: return "V4"; case V5: return "V5"; + case V6: return "V6"; case WEIGHT_ONLY: return "WEIGHT_ONLY"; default: return "Unknown"; } diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/ndhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_bf16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf index 9222a0858f..7cd2a3d85e 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp16.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf index b9704c8100..e7ea32680d 100644 --- a/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf +++ b/experimental/grouped_convolution_tile_instances/configs/tests/nhwgc_fp32.conf @@ -20,9 +20,9 @@ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 32, Filter1x1Stri DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Filter1x1Stride1Pad0, 16, 16, 8, 8, 8, 8, 8, 1, 2, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> -DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> +# DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 128, 64, Filter1x1Stride1Pad0, 32, 32, 2, 2, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v5> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Default, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 128, 256, 32, Filter1x1Stride1Pad0, 32, 32, 2, 4, 8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1> diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index a1be8027b2..2ba3b1e7c3 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -19,7 +19,7 @@ namespace ck_tile { /** @brief Maximum number of error values to display when checking errors */ -constexpr int ERROR_DETAIL_LIMIT = 128; +constexpr int ERROR_DETAIL_LIMIT = 16; /** @brief 8-bit floating point type */ using F8 = ck_tile::fp8_t; diff --git a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp index 9f7227a699..9accf6e336 100644 --- a/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp +++ b/profiler/include/profiler/grouped_convolution_forward_tile_algs.hpp @@ -7,6 +7,7 @@ #include "../../experimental/builder/test/utils/conv_algorithm_type_utils.hpp" #include "grouped_convolution_signatures.hpp" +#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp" #include "ck_tile/builder/testing/filter_extent.hpp" #include "ck_tile/builder/testing/conv/fwd.hpp" @@ -14,6 +15,9 @@ #include "ck_tile/builder/testing/conv/reference.hpp" #include "ck_tile/builder/conv_builder.hpp" +// Temporary disable builder validate since we don't have deduced rtol, atol support +#define ENABLE_BUILDER_VALIDATE 0 + namespace ck_tile::builder::profiling { namespace ckb = ck_tile::builder; @@ -117,22 +121,63 @@ run_grouped_conv_forward_tile_algs(const ckt::Args& args, auto ref_conv = ReferenceInstance{}; [[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get()); +#if ENABLE_BUILDER_VALIDATE == 0 + using DataType = + std::conditional_t>; + const auto conv_param = args.to_ck_tile_conv_param(); + + const std::size_t output_bytes_num = conv_param.template GetOutputByte(); + std::vector out(output_bytes_num / sizeof(DataType)); + std::vector ref(output_bytes_num / sizeof(DataType)); + HIP_CHECK_ERROR( + hipMemcpy(&ref.data()[0], reference.get().output, output_bytes_num, hipMemcpyDeviceToHost)); + + const ck_tile::index_t GemmK = std::accumulate(conv_param.filter_spatial_lengths_.cbegin(), + conv_param.filter_spatial_lengths_.cend(), + 1, + std::multiplies()) * + conv_param.C_; + float max_accumulated_value = *std::max_element(ref.begin(), ref.end()); + const auto rtol = ck_tile::get_relative_threshold(GemmK); + const auto atol = + ck_tile::get_absolute_threshold(max_accumulated_value, GemmK); +#endif + [[maybe_unused]] auto run_alg = [&](auto&& run_alg_func) { std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf); if(is_supported) { + best_avg_time = std::min(best_avg_time, avg_time); + best_op_name = best_avg_time < avg_time ? best_op_name : op_name; + std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name + << std::endl; + +#if ENABLE_BUILDER_VALIDATE const auto errors = ckt::validate(args, outputs, reference.get()).get_errors(); for(const auto& error : errors) { valid = false; std::cout << "Number of incorrect values: " << error.wrong_elements - << " Is all zero:" << error.is_all_zero() << std::endl; + << " Is all zero:" << error.is_all_zero() + << " max err: " << error.max_error << std::endl; } - best_avg_time = std::min(best_avg_time, avg_time); - best_op_name = best_avg_time < avg_time ? best_op_name : op_name; - std::cout << "Perf: " << std::setw(10) << avg_time << " ms,"; +#else + HIP_CHECK_ERROR( + hipMemcpy(&out.data()[0], outputs.output, output_bytes_num, hipMemcpyDeviceToHost)); + valid = ck_tile::check_err(out, ref, "Error: Incorrect results!", rtol, atol); +#endif + + std::cout << "Relative error threshold: " << rtol + << " Absolute error threshold: " << atol << std::endl; + } + else + { + std::cout << " " << op_name << std::endl; } - std::cout << " " << op_name << std::endl; }; if constexpr(SIGNATURE == SIGNATURE_NHWGC_FP16_FWD) diff --git a/test/grouped_convnd_fwd/CMakeLists.txt b/test/grouped_convnd_fwd/CMakeLists.txt index 6f8b71679c..725c5716d9 100644 --- a/test/grouped_convnd_fwd/CMakeLists.txt +++ b/test/grouped_convnd_fwd/CMakeLists.txt @@ -21,13 +21,12 @@ endif() if(GPU_TARGETS MATCHES "gfx9") if(CK_EXPERIMENTAL_BUILDER) - # TODO: Reenable after the instance fixes - # add_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) - # target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) - # if(TARGET device_grouped_conv_fwd_tile_instances) - # target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) - # endif() + add_gtest_executable(test_grouped_convnd_fwd_tile test_grouped_convnd_fwd_tile.cpp) + target_compile_options(test_grouped_convnd_fwd_tile PRIVATE -Wno-global-constructors -Wno-undef -Wno-c++20-compat) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE gtest_main getopt::getopt utility) + if(TARGET device_grouped_conv_fwd_tile_instances) + target_link_libraries(test_grouped_convnd_fwd_tile PRIVATE device_grouped_conv_fwd_tile_instances) + endif() endif() endif() diff --git a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp index 068811cf00..fe517572ff 100644 --- a/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp +++ b/test/grouped_convnd_fwd/test_grouped_convnd_fwd_tile.cpp @@ -13,6 +13,8 @@ // TODO: Remove limitation of conv fwd gpu reference which does not support right pad #define CK_CONV_FWD_REF_SKIP_RIGHT_PAD_CASES 1 +// TODO: Remove this limitation after gpu reference fix +#define ENABLE_BHALF_GROUPED_CONV_FWD_TESTS 0 static ck::index_t args_mask = 0xffff; static ck::index_t instance_index = -1; @@ -67,7 +69,10 @@ class TestGroupedConvndFwdTile : public ::testing::Test auto inputs = alloc_inputs(args); auto outputs = alloc_outputs(args); - ckt::init_inputs(args, inputs.get()); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().input, args.make_input_descriptor(), -5, 5); + ckt::init_tensor_buffer_uniform_fp( + inputs.get().weight, args.make_weight_descriptor(), -5, 5); std::cout << args.make_input_descriptor() << std::endl; std::cout << args.make_weight_descriptor() << std::endl; @@ -150,13 +155,12 @@ using KernelTypes2d = ::testing::Types, - SignatureDetails<2, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NHWGC, - ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 2, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NHWGC, + ckb::TensorLayout::GKYXC, ckb::TensorLayout::NHWGK >> + ; +#endif using KernelTypes3d = ::testing::Types, - SignatureDetails<3, - ckb::DataType::BF16, - ckb::DataType::FP32, - ckb::TensorLayout::NDHWGC, - ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK>>; +#if ENABLE_BHALF_GROUPED_CONV_FWD_TESTS +SignatureDetails < 3, ckb::DataType::BF16, ckb::DataType::FP32, ckb::TensorLayout::NDHWGC, + ckb::TensorLayout::GKZYXC, ckb::TensorLayout::NDHWGK >> + ; +#endif template class TestGroupedConvndFwdTile2d : public TestGroupedConvndFwdTile From b66597ed96180ce21e7e6a6678dfc232ed07c800 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 27 Jan 2026 05:07:27 -0800 Subject: [PATCH 23/28] Add build time optimization documentation (#3608) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase, including: - Replacing recursive templates with pack expansion (O(N) → O(1) depth) - Using named functors instead of lambdas to share instantiations - Replacing template recursion with constexpr loops - Using fold expressions for accumulation operations These techniques can significantly reduce build times for template-heavy code. --- include/ck/BUILD_TIME_OPTIMIZATION.md | 225 ++++++++++++++++++++++++++ 1 file changed, 225 insertions(+) create mode 100644 include/ck/BUILD_TIME_OPTIMIZATION.md diff --git a/include/ck/BUILD_TIME_OPTIMIZATION.md b/include/ck/BUILD_TIME_OPTIMIZATION.md new file mode 100644 index 0000000000..94b292b878 --- /dev/null +++ b/include/ck/BUILD_TIME_OPTIMIZATION.md @@ -0,0 +1,225 @@ +# Build Time Optimization + +Tracking issue: [#3575](https://github.com/ROCm/composable_kernel/issues/3575) + +This document describes techniques for reducing C++ template instantiation overhead in the Composable Kernel codebase. + +## Why Build Time Matters + +Composable Kernel relies heavily on C++ template metaprogramming to achieve GPU kernels with no runtime abstraction penalty. However, deep template instantiation can significantly impact build times. A single translation unit may trigger hundreds of thousands of template instantiations, with each instantiation adding to compile time. + +## Key Types + +This codebase uses compile-time types to enable zero-overhead abstractions: + +- `Number` - compile-time integer, enables static dispatch and compile-time arithmetic +- `Sequence` - compile-time integer sequence, used for dimension ordering and index manipulation +- `Tuple` - heterogeneous container holding different types, used for tensor descriptors and transforms + +These types allow the compiler to fully unroll loops, eliminate branches, and inline all operations - producing GPU kernels with no runtime abstraction cost. + +## Optimization Techniques + +### 1. Replace Recursive Templates with Pack Expansion + +Recursive template patterns create O(N) instantiation depth - the compiler must instantiate each level before proceeding to the next: + +``` +sequence_gen_impl<5, F> + → sequence_gen_impl<4, F> + → sequence_gen_impl<3, F> + → ... +``` + +Using `__make_integer_seq` (Clang/MSVC) combined with pack expansion reduces this to constant depth - the compiler generates the entire sequence in one step internally, without recursive template instantiation. + +**Before** (O(N) recursive instantiation): + +```cpp +template +struct sequence_gen_impl +{ + using type = typename sequence_gen_impl{}), Is...>::type; +}; + +template +struct sequence_gen_impl<0, F, Is...> +{ + using type = Sequence; +}; +``` + +**After** (constant depth using compiler intrinsic + pack expansion): + +```cpp +namespace detail { + +template +struct sequence_gen_helper +{ + // Apply functor F to all indices via pack expansion + // F{}(Number<0>{}), F{}(Number<1>{}), ..., F{}(Number{}) + template + using apply = Sequence{})...>; +}; + +} // namespace detail + +template +struct sequence_gen +{ + // __make_integer_seq produces + // sequence_gen_helper with constant depth + using type = + typename __make_integer_seq::template apply; +}; +``` + +Note: This document assumes C++17 or later. While `std::make_integer_sequence` (introduced in C++14) is the standard library facility for generating integer sequences, it only produces `std::integer_sequence`. We use `__make_integer_seq` directly because it accepts any template as its first argument, enabling this pattern where the helper class receives the index pack directly. + +### 2. Replace Lambdas with Named Functors + +Each lambda expression creates a unique closure type, causing separate template instantiations at every call site. Named functors share a single type across all uses. + +**Before** (lambda creates unique instantiations at each call site): + +```cpp +// The lambda inside transform_tensor_descriptor: +generate_tuple([](auto i) { return Sequence{}; }, Number{}); +``` + +**After** (named functor shares instantiations): + +```cpp +// Define functor once +struct generate_identity_sequence +{ + template + __host__ __device__ constexpr auto operator()(Number) const + { + return Sequence{}; + } +}; + +// Use everywhere - shares instantiations +generate_tuple(generate_identity_sequence{}, Number{}); +``` + +This reduced `transform_tensor_descriptor` instantiations from 388 to 32 (92% reduction). + +**Example: container_concat** + +```cpp +// Before: lambda creates unique type per call site +// (unpack2 applies a functor to all elements from both tuples) +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2([](auto&&... zs) { return make_tuple(forward(zs)...); }, tx, ty); +} + +// After: named functor shares instantiations +struct make_tuple_functor +{ + template + __host__ __device__ constexpr auto operator()(Ts&&... xs) const + { + return make_tuple(forward(xs)...); + } +}; + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2(make_tuple_functor{}, tx, ty); +} +``` + +This reduced `container_concat` instantiations from 186 to 93 (50% reduction). + +**Example: make_uniform_tuple** + +For patterns that create tuples with repeated values: + +```cpp +// Before: unique lambda type at each call site +generate_tuple([](auto) { return some_value; }, Number{}); + +// After: dedicated helper function +template +__host__ __device__ constexpr auto make_uniform_tuple(T&& value) +{ + return detail::make_uniform_tuple_impl(static_cast(value), make_index_sequence{}); +} + +// Usage +make_uniform_tuple(some_value); +``` + +### 3. Use Constexpr Loops Instead of Template Recursion + +Template recursion creates N template instantiations for N iterations. A constexpr loop executes at compile time but only requires a single template instantiation. While both are O(N) in complexity, constexpr loops are significantly faster because they avoid the overhead of template instantiation. + +**Before** (O(N) template instantiations): + +```cpp +// Simplified example - actual CK code used more complex recursive patterns +template +struct find_source_index_impl +{ + static constexpr index_t value = + (Seq::template At() == Target) ? Pos : find_source_index_impl::value; +}; + +template +struct find_source_index_impl +{ + static constexpr index_t value = -1; // not found +}; +``` + +**After** (single instantiation with constexpr loop): + +```cpp +template +__host__ __device__ constexpr index_t find_source_index(Sequence) +{ + // Simplified example - actual implementation handles empty sequences + constexpr index_t values[] = {Is...}; + for(index_t i = 0; i < sizeof...(Is); ++i) + if(values[i] == Target) return i; + return -1; // not found +} +``` + +This reduced `sequence_map_inverse` instantiations from 45 to 10 (78% reduction) and wall-clock time by 95%. + +### 4. Use Fold Expressions for Accumulation + +Fold expressions (C++17) can replace recursive template patterns for accumulation operations. + +**Before** (uses helper utilities that hide template recursion: `generate_tuple` recursively constructs a tuple of N elements, and `container_reduce` recursively reduces that tuple): + +```cpp +const auto element_space_size = container_reduce( + generate_tuple([&](auto i) { + return (lengths[i] - Number<1>{}) * strides[i]; + }, Number{}), + math::plus{}, Number<1>{}); +``` + +**After** (single fold expression): + +```cpp +template +__host__ __device__ constexpr auto compute_element_space_size( + const Tuple& lengths, + const Tuple& strides, + Sequence) +{ + return (LongNumber<1>{} + ... + + ((lengths[Number{}] - Number<1>{}) * strides[Number{}])); +} +``` + +This reduced `calculate_element_space_size` instantiations from 24 to 10 (58% reduction) and wall-clock time by 73%. From ad3954f11958d57b363d371bf867843417205124 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Tue, 27 Jan 2026 08:46:53 -0500 Subject: [PATCH 24/28] Enable bwd weight splitk autodeduction with cap --- .../gpu/device/device_grouped_conv_bwd_weight.hpp | 2 -- .../device_grouped_conv_bwd_weight_multiple_d.hpp | 2 -- .../device_grouped_conv_bwd_weight_explicit.hpp | 15 +++------------ ...onv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 12 +++--------- ...ed_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp | 11 +++-------- ...e_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp | 12 ++++-------- ...evice_grouped_conv_bwd_weight_xdl_cshuffle.hpp | 11 +++-------- ...ce_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp | 12 +++--------- 8 files changed, 19 insertions(+), 58 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 58da96e2f0..eadfa29c9f 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,8 +11,6 @@ namespace ck { namespace tensor_operation { namespace device { -#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 - template ()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index bc072a7019..76f79ac14a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -585,7 +585,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -602,6 +601,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -611,7 +613,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -988,13 +989,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index d3bf2a364a..033e0b8745 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -677,7 +677,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -688,9 +687,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -947,12 +948,6 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index 3f8093afe1..b2ae092c27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -511,7 +511,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -528,6 +528,9 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 const auto k_batch_max = math::integer_divide_ceil((gemmK - 1), KPerBlock); k_batch_ = std::min(k_batch_, k_batch_max); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -537,7 +540,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1040,12 +1042,6 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 976b6f1ef8..3e19b082ee 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -651,7 +651,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -662,9 +661,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle calculate_mn_grid_size(gemmM, gemmN) * Conv_G_; k_batch_ = get_best_occupancy_k_batch_value(active_workgroups_per_cu.max_occupancy_, grid_size); + + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); } else -#endif { k_batch_ = split_k; } @@ -1083,12 +1084,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 2121be00d1..3fb6e92039 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -556,7 +556,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -573,6 +572,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 const auto k_batch_max = static_cast((gemmK - 1) / K0PerBlock); k_batch_ = std::max(std::min(k_batch_, k_batch_max), 1); + // Cap k_batch_ to 128 to avoid accuracy issues + k_batch_ = std::min(k_batch_, 128); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "[SPLIT-K AUTODEDUCE] k_batch max value: " << k_batch_max @@ -582,7 +584,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else -#endif { k_batch_ = split_k; } @@ -1360,13 +1361,6 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(arg.k_batch_ < 0) - { - return false; - } -#endif - const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); const index_t GemmK = From 74eb200c7348243dc87adf166f42d0b6c20061e4 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Tue, 27 Jan 2026 09:23:12 -0500 Subject: [PATCH 25/28] Fix error threshold calculations --- .../profile_grouped_conv_bwd_weight_impl.hpp | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 3a9f14e595..078bc51077 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -364,26 +364,36 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; - // Calculate number of accumulations accounting for split_k - const int num_accums = - static_cast(output.GetElementSize() / conv_param.K_ / split_k_value); - - // Additional tolerance for split_k accumulation if needed - int total_accums = num_accums; - if(split_k_value > 1) - { - total_accums = std::max(num_accums, static_cast(split_k_value)); - } + const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums_split_k = split_k_value; + // Calculate thresholds + auto rtol = + ck::utils::get_relative_threshold( + num_accums / num_accums_split_k); + auto atol = + ck::utils::get_absolute_threshold( + max_accumulated_value / num_accums_split_k, + num_accums / num_accums_split_k); + // Calculate error due to split_k accumulation + auto rtol_split_k = + ck::utils::get_relative_threshold( + num_accums_split_k); + auto atol_split_k = + ck::utils::get_absolute_threshold( + max_accumulated_value, num_accums_split_k); + // Use higher threshold + rtol = std::max(rtol, rtol_split_k); + atol = std::max(atol, atol_split_k); // Perform GPU verification (max value computed internally on GPU) const std::size_t tensor_size = weight_device_result.mDesc.GetElementSpaceSize(); auto gpu_result = - ck::profiler::gpu_verify( - wei_device_buf.GetDeviceBuffer(), - gpu_ref_wei_buf.GetDeviceBuffer(), - total_accums, - tensor_size); + ck::profiler::gpu_verify(wei_device_buf.GetDeviceBuffer(), + gpu_ref_wei_buf.GetDeviceBuffer(), + rtol, + atol, + tensor_size); if(!gpu_result) { From 55d8e9b4f02c760b0fa956e48c7f74dd7160b976 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 28 Jan 2026 01:56:55 -0500 Subject: [PATCH 26/28] Add missing logic to wmma multiple d kernel --- ...bwd_weight_multiple_d_wmma_cshuffle_v3.hpp | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 76f79ac14a..f662ff834f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -22,6 +22,7 @@ #include #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/tensor_operation/gpu/device/impl/split_k_arg.hpp" +#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/host_utility/device_prop.hpp" @@ -524,6 +525,44 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 decltype(GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( CGridDesc_M_N{}, 1, 1)); + struct ActiveWorkgroupsPerCU + { + ActiveWorkgroupsPerCU() + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return; + } + constexpr int dynamic_smem_size = 0; + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + int max_occupancy = 0; + + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // TODO: implement + } + else + { + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, + kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>, + BlockSize, + dynamic_smem_size)); + } + max_occupancy_ = std::max(1, max_occupancy); + } + int max_occupancy_; + }; + struct Argument : public BaseArgument, public ArgumentSplitK { Argument( @@ -574,6 +613,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { + static ActiveWorkgroupsPerCU active_workgroups_per_cu; + constexpr index_t spatial_offset = 3; std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, end(b_g_n_c_wis_lengths), From 0eee2d3392801c9d7ed022a07819e3fca378c313 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 28 Jan 2026 09:18:03 -0500 Subject: [PATCH 27/28] Fix threshold calculation --- .../profiler/profile_grouped_conv_bwd_weight_impl.hpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 078bc51077..afc88150ed 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -366,6 +366,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, const index_t num_accums = output.GetElementSize() / conv_param.K_; const index_t num_accums_split_k = split_k_value; + // Get maximum accumulated value from reference + const std::size_t tensor_size = + weight_device_result.mDesc.GetElementSpaceSize(); + max_accumulated_value = + gpu_reduce_max(gpu_ref_wei_buf.GetDeviceBuffer(), tensor_size); // Calculate thresholds auto rtol = ck::utils::get_relative_threshold( @@ -385,9 +390,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, rtol = std::max(rtol, rtol_split_k); atol = std::max(atol, atol_split_k); - // Perform GPU verification (max value computed internally on GPU) - const std::size_t tensor_size = - weight_device_result.mDesc.GetElementSpaceSize(); + // Perform GPU verification auto gpu_result = ck::profiler::gpu_verify(wei_device_buf.GetDeviceBuffer(), gpu_ref_wei_buf.GetDeviceBuffer(), From 029efffeb51198ee5a5a51a8668715d14cf8b181 Mon Sep 17 00:00:00 2001 From: "Graner, Johannes" Date: Wed, 28 Jan 2026 07:24:09 -0500 Subject: [PATCH 28/28] Update test with new applicability --- .../test_grouped_convnd_bwd_weight_interface_xdl.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index bce6da4b68..5aa0b13c07 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -184,5 +184,5 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; this->split_k_ = -1; bool is_supported = this->template Run<2>(); - EXPECT_FALSE(is_supported); + EXPECT_TRUE(is_supported); }