From bd6070fb5cf80b7440bb4f692faae3a85a1fa353 Mon Sep 17 00:00:00 2001
From: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
Date: Fri, 6 Mar 2026 09:26:40 -0700
Subject: [PATCH] Compile-time optimize threadwise slice transfer (#4673)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

Profiling with `-ftime-trace` on representative translation units (e.g., `device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp`) revealed that **92% of frontend time was spent in template instantiation**. The primary bottleneck was redundant instantiation of identical helper logic across multiple threadwise transfer class variants.

Each `ThreadwiseTensorSliceTransfer_v*` class independently contained its own copy of the same helper computations — serpentine traversal, coordinate stepping, thread scratch descriptors, lambda-like functors, and compile-time constants — duplicated across 13 header files. When a typical GEMM or convolution kernel TU includes blockwise operations (e.g., `blockwise_gemm_xdlops.hpp`), it pulls in multiple transfer variants simultaneously, causing the compiler to instantiate the same helper logic multiple times with the same template arguments.

This was compounded by the helpers being defined as members of the outer `ThreadwiseTensorSliceTransfer_v*` classes, which carry 14+ template parameters. Functions like `ComputeForwardSweep` depend only on their two argument types, but as inline members of the outer class, the compiler was forced to create separate instantiations for every unique combination of all outer parameters (data types, descriptors, vector widths, etc.) — even when most of those parameters had no effect on the helper's output.
## Technical Details

### The Fix: Shared Helper Struct Hierarchy

Duplicated logic was extracted into a standalone helper hierarchy in `threadwise_tensor_slice_transfer_util.hpp`:

```
ThreadwiseTransferHelper_Base             (I0..I16, MoveSliceWindow, ComputeThreadScratchDescriptor,
 |                                         ComputeForwardSteps, ComputeBackwardSteps, MakeVectorContainerTuple)
 +-- ThreadwiseTransferHelper_Serpentine  (ComputeForwardSweep, ComputeMoveOnDim, ComputeDataIndex,
 |                                         ComputeCoordinateResetStep, VectorSizeLookupTable, VectorOffsetsLookupTable)
 +-- ThreadwiseTransferHelper_SFC         (ComputeSFCCoordinateResetStep)
```

Each helper method is now parameterized **only by what it actually uses**:

- `ComputeForwardSweep(idx, lengths)` — parameterized only by the two argument types, not by `SrcData`, `DstData`, `SrcDesc`, etc.
- `ComputeForwardSteps(desc, scalar_per_access)` — parameterized only by the descriptor and access sequence types.
- `ComputeCoordinateResetStep()` — parameterized only by the four values it actually needs.

This reduces template instantiation work in two ways:

1. **Across different transfer variants** (v3r1 vs v3r2 vs v3r1_gather): the compiler reuses a single instantiation instead of creating one per variant.
2. **Across different outer class instantiations** (fp16 vs bf16 vs int8): the compiler reuses the helper instantiation because the helper doesn't depend on the data type at all.

### Refactored Headers

**13 headers** now delegate to the shared helpers instead of duplicating logic:

- Serpentine family: v3r1, v3r2, v3r1_gather, v3r1_dequant
- SFC family: v6r1, v6r1r2, v6r2, v6r3, v7r2, v7r3, v7r3_scatter
- Dead code removed: v4r1, v5r1

### Additional Fixes Found During Refactoring

- Two latent bugs in v3r2 (`forward_sweep` indexing, `GetDstCoordinateResetStep` extraction)
- Dead `SrcCoordStep` variables in v4r1 and v5r1
- Unused `scale_element_op_` member in v3r1_dequant (restored with note)

### Net Code Change

+1,428 / -2,297 lines (~870 lines removed).
## Test Plan

### Unit Tests

28 host-side gtests in `test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp` covering the full helper hierarchy:

| Suite | Tests | What is verified |
|-------|-------|------------------|
| ThreadwiseTransferHelperBase | 6 | Compile-time constants, inheritance, `MoveSliceWindow` with `ResetCoordinateAfterRun` true/false in 2D and 3D |
| ThreadwiseTransferHelperSerpentine | 9 | `ComputeForwardSweep` (even/odd row, 1D), `ComputeMoveOnDim` (inner complete/incomplete), `ComputeDataIndex`, `ComputeCoordinateResetStep`, `VectorSizeLookupTable`, `VectorOffsetsLookupTable` |
| ThreadwiseTransferHelperSFC | 6 | `ComputeSFCCoordinateResetStep` — single access, 2D row-major, 2D column-major, 3D batch, even/odd inner access counts |
| ThreadwiseTransferHelperInheritance | 3 | Serpentine and SFC derive from Base, are not related to each other |
| DetailFunctors | 4 | `lambda_scalar_per_access`, `lambda_scalar_step_in_vector`, `lambda_scalar_per_access_for_src_and_dst` (same dim, different dims) |

### Semantic Equivalence

GPU ISA comparison using `--cuda-device-only -S` confirmed identical assembly output (modulo `__hip_cuid_*` metadata) between baseline and refactored code.

## Test Results

All measurements on a 384-core machine, `-j64`, freshly rebooted, near-idle.
### Targeted Builds (affected targets only)

| Target | Baseline | Refactored | Wall-clock Delta | CPU Delta |
|--------|----------|------------|------------------|-----------|
| `device_grouped_conv2d_fwd_instance` (160 TUs) | 7m 37s / 189m CPU | 6m 53s / 161m CPU | **-9.7%** | **-14.9%** |
| `device_grouped_conv3d_fwd_instance` (185 TUs) | 9m 49s / 202m CPU | 6m 42s / 182m CPU | **-31.8%** | **-10.0%** |
| **Combined** | **17m 27s / 392m CPU** | **13m 35s / 344m CPU** | **-22.2%** | **-12.4%** |

### Full Project Build (8,243 targets)

| Metric | Baseline | Refactored | Delta |
|--------|----------|------------|-------|
| Wall-clock | 103m 38s | 111m 56s | +8.0%* |
| CPU time | 4705m 7s | 4648m 17s | **-1.2%** |

\*Wall-clock inflated by external load spike during refactored build (load 90 vs 66). CPU time is the reliable metric.

### Context

~15% of all build targets (1,262 / 8,243) transitively include the modified headers. These are primarily GEMM and convolution kernel instantiations — the core compute workloads. The 12-15% CPU savings on affected targets is diluted to 1.2% across the full project because 85% of targets are unaffected.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------

Co-authored-by: Claude Opus 4.6
---
 .../threadwise_tensor_slice_transfer_util.hpp | 481 ++++++++++-
 .../threadwise_tensor_slice_transfer_v3r1.hpp | 740 ++++++-----------
 ...ise_tensor_slice_transfer_v3r1_dequant.hpp | 618 ++--------------
 ...wise_tensor_slice_transfer_v3r1_gather.hpp | 434 ++----------
 .../threadwise_tensor_slice_transfer_v3r2.hpp | 441 ++----------
 .../threadwise_tensor_slice_transfer_v4r1.hpp |   2 -
 .../threadwise_tensor_slice_transfer_v5r1.hpp |   3 -
 .../threadwise_tensor_slice_transfer_v6r1.hpp |  49 +-
 ...hreadwise_tensor_slice_transfer_v6r1r2.hpp |  49 +-
 .../threadwise_tensor_slice_transfer_v6r2.hpp |  63 +-
 .../threadwise_tensor_slice_transfer_v6r3.hpp |  75 +-
 .../threadwise_tensor_slice_transfer_v7r2.hpp | 106 +--
 .../threadwise_tensor_slice_transfer_v7r3.hpp | 108 +--
 ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 107 +--
 test/CMakeLists.txt                           |   1 +
 .../threadwise_transfer_helper/CMakeLists.txt |   4 +
 .../test_threadwise_transfer_helper.cpp       | 748 ++++++++++++++++++
 17 files changed, 1747 insertions(+), 2282 deletions(-)
 create mode 100644 test/threadwise_transfer_helper/CMakeLists.txt
 create mode 100644 test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp

diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
index 5035fe23d0..c2b54e2ba3 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
@@ -1,19 +1,28 @@
 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT #pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + namespace ck { -// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory -// and sometimes useless instructions: -// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument -// instead -// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same -// tensor coordinate instead -// 3. Don't use a pointer to VGPR buffer, use vector instead +/** + * @file threadwise_tensor_slice_transfer_util.hpp + * @brief Shared helper class hierarchy for threadwise tensor slice transfer variants. + * + * Provides a three-tier inheritance structure: + * + * - @ref ThreadwiseTransferHelper_Base -- generic coordinate/descriptor utilities + * - @ref ThreadwiseTransferHelper_Serpentine -- serpentine (snake/zigzag) traversal + * - @ref ThreadwiseTransferHelper_SFC -- SpaceFillingCurve traversal + */ namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor + +/** @brief Functor returning ScalarPerVector for dimension VectorDim, 1 otherwise. */ template struct lambda_scalar_per_access { @@ -23,6 +32,7 @@ struct lambda_scalar_per_access } }; +/** @brief Functor returning 1 for dimension VectorDim, 0 otherwise. */ template struct lambda_scalar_step_in_vector { @@ -32,8 +42,10 @@ struct lambda_scalar_step_in_vector } }; -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor +/** + * @brief Functor computing scalar-per-access for combined src/dst vector dimensions. + * Returns lcm when both src and dst share the same vector dimension. 
+ */ template -struct lambda_wave_cluster_dimension +} // namespace detail + +/** + * @brief Base helper with methods shared by all threadwise transfer variants. + * + * Both ThreadwiseTransferHelper_Serpentine and ThreadwiseTransferHelper_SFC + * inherit from this class. Contains generic coordinate stepping, thread scratch + * descriptor construction, and compile-time index constants. + */ +struct ThreadwiseTransferHelper_Base { - __host__ __device__ constexpr auto operator()(index_t i) const + /** + * @name Compile-time index constants + * @{ + */ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + /** @} */ + + /** + * @brief Move the slice window by a step, optionally fusing coordinate reset. + * + * If the coordinate was not reset after RunRead/RunWrite, the reset step is + * added to the movement step to avoid a separate coordinate adjustment. + * + * @tparam ResetCoordinateAfterRun Whether the coordinate was already reset. + * @param desc Tensor descriptor. + * @param coord Tensor coordinate to move (modified in place). + * @param slice_origin_step_idx Step index for the slice window movement. + * @param get_reset_step Callable returning the coordinate reset step. 
+ */ + template + __host__ __device__ static void MoveSliceWindow(const Desc& desc, + Coord& coord, + const StepIdx& slice_origin_step_idx, + GetCoordinateResetStepFunc get_reset_step) { - if((nDim - i) == 3) - return WaveNum; - else - return 1; + const auto adjusted_step_idx = ResetCoordinateAfterRun + ? slice_origin_step_idx + : slice_origin_step_idx + get_reset_step(); + + const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx); + + move_tensor_coordinate(desc, coord, adjusted_step); + } + + /** + * @brief Build the thread-local scratch tensor descriptor. + * + * Creates a transformed tensor descriptor where the vector dimension is merged + * with an additional dimension of size ScalarPerVector, enabling vector-typed + * access to the scratch buffer. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam VectorDim Which dimension is vectorized. + * @tparam ScalarPerVector_ Number of scalars per vector load/store. + * @return Transformed tensor descriptor for the thread scratch buffer. 
+ */ + template + __host__ __device__ static constexpr auto ComputeThreadScratchDescriptor() + { + constexpr index_t nDim = SliceLengths::Size(); + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(access_lengths_and_vector_length); + + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == VectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(access_lengths_and_vector_length[i], + access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == VectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = generate_identity_sequences(); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + /** + * @brief Compute forward (+1) coordinate steps for each dimension. + * + * Returns a tuple of nDim coordinate steps, where step[i] moves by + * +scalar_per_access[i] in dimension i and 0 in all other dimensions. + * + * @param desc Tensor descriptor. + * @param scalar_per_access Per-dimension access widths (Sequence type). + */ + template + __host__ __device__ static constexpr auto + ComputeForwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + return generate_tuple( + [&](auto i) { + MultiIndex step_idx; + + static_for<0, nDim, 1>{}( + [&](auto j) { step_idx(j) = (i.value == j.value) ? 
scalar_per_access[i] : 0; }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + /** + * @brief Compute backward (-1) coordinate steps for each dimension. + * + * Same as ComputeForwardSteps but with negated step values. + * + * @param desc Tensor descriptor. + * @param scalar_per_access Per-dimension access widths (Sequence type). + */ + template + __host__ __device__ static constexpr auto + ComputeBackwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + return generate_tuple( + [&](auto i) { + MultiIndex step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + /** + * @brief Create a tuple of default-constructed vector containers, one per data type. + * + * @tparam DataTypes Tuple of data types (e.g., Tuple). + * @tparam ScalarPerVector Number of scalars per vector. + * @return Tuple of vector_type_maker_t instances. + */ + template + __host__ __device__ static auto MakeVectorContainerTuple() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); } }; -} // namespace detail +/** + * @brief Serpentine (snake/zigzag) traversal helper. + * + * Provides methods for computing serpentine sweep directions, dimension movement + * decisions, and coordinate reset steps used by the v3r1 family of transfer classes. + * + * Used by: ThreadwiseTensorSliceTransfer_v3r1, v3r2, v3r1_gather, v3r1_dequant. + */ +struct ThreadwiseTransferHelper_Serpentine : ThreadwiseTransferHelper_Base +{ + /** + * @brief Binary decomposition of vector widths 0-16 into power-of-2 sub-load sizes. + * Index N gives the sequence of sub-load widths whose sum equals N. + * E.g. 
index 7 -> Sequence means loads of width 4, 2, 1. + */ + using VectorSizeLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + /** + * @brief Starting offsets for each sub-load in VectorSizeLookupTable. + * E.g. index 7 -> Sequence means offsets 0, 4, 6. + */ + using VectorOffsetsLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + /** + * @brief Compute serpentine sweep direction for each dimension. + * + * Determines whether each dimension should be traversed forward or backward + * based on the current position in the ordered access grid, implementing + * a zigzag (serpentine) traversal pattern. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @return Array of booleans: true = forward, false = backward. + */ + template + __host__ __device__ static constexpr auto + ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + constexpr index_t nDim = OrderedAccessLengths::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "ordered_access_idx and ordered_access_lengths must have same nDim"); + StaticallyIndexedArray_v2 forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + } + + /** + * @brief Determine which dimensions need coordinate movement at a given iteration. 
+ * + * A dimension moves when it hasn't reached its end and all higher-priority + * (faster-varying) dimensions have completed their ranges. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @return Array of booleans: true = move coordinate on this dimension. + */ + template + __host__ __device__ static constexpr auto + ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + constexpr index_t nDim = OrderedAccessLengths::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "ordered_access_idx and ordered_access_lengths must have same nDim"); + StaticallyIndexedArray_v2 move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + + /** + * @brief Convert ordered access index to natural dimension order and apply scaling. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @param forward_sweep Per-dimension sweep direction. + * @param dim_access_order Mapping from ordered to natural dimension indices. + * @param scalar_per_access Per-dimension access widths. + * @return MultiIndex in natural dimension order, scaled by scalar_per_access. 
+ */ + template + __host__ __device__ static constexpr auto + ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths, + const ForwardSweep& forward_sweep, + const DimAccessOrder& dim_access_order, + const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "all arguments to ComputeDataIndex must have same nDim"); + MultiIndex ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; + } + + /** + * @brief Compute the coordinate step needed to return to the origin after traversal. + * + * Determines where the coordinate ends up after a full serpentine traversal, + * then returns the negated position as the reset step. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam VectorDim Which dimension is vectorized. + * @tparam ScalarPerVector_ Number of scalars per vector load/store. + * @tparam DimAccessOrder Compile-time sequence mapping ordered to natural dims. + * @return MultiIndex representing the step to reset the coordinate to the origin. 
+ */ + template + __host__ __device__ static constexpr auto ComputeCoordinateResetStep() + { + constexpr index_t nDim = SliceLengths::Size(); + static_assert(DimAccessOrder::Size() == nDim, + "SliceLengths and DimAccessOrder must have same nDim"); + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + constexpr auto ordered_access_lengths_minus_1 = generate_tuple( + [&](auto i) { return Number{}; }, Number{}); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths); + + constexpr auto reset_step = [&]() { + MultiIndex ordered_idx; + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + auto data_idx = + container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; + + MultiIndex step; + static_for<0, nDim, 1>{}([&](auto i) { step(i) = -data_idx[i]; }); + return step; + }(); + + return reset_step; + } +}; + +/** + * @brief SpaceFillingCurve traversal helper. + * + * Provides coordinate reset computation using SpaceFillingCurve's GetStepBetween + * method, which computes the step from the last access position back to the origin. + * + * Used by: ThreadwiseTensorSliceTransfer v6r1, v6r1r2, v6r2, v6r3, v7r2, v7r3, + * v7r3_scatter. + */ +struct ThreadwiseTransferHelper_SFC : ThreadwiseTransferHelper_Base +{ + /** + * @brief Compute the coordinate reset step using SpaceFillingCurve traversal. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam DimAccessOrder Compile-time sequence defining dimension access order. + * @tparam ScalarPerAccess Compile-time sequence of per-dimension access widths. 
+ * @return MultiIndex representing the step from last access position to origin. + */ + template + __host__ __device__ static constexpr auto ComputeSFCCoordinateResetStep() + { + using SFC = SpaceFillingCurve>; + + constexpr auto num_access = SFC::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SFC::Index{}; + } + else + { + return SFC::GetStepBetween(Number{}, Number<0>{}); + } + } +}; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 7b9d136068..8b0b35935f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -44,263 +44,63 @@ template struct ThreadwiseTensorSliceTransfer_v3r1 { + // ===================================================================== + // Private type aliases and constants + // ===================================================================== + private: + using Helper = ThreadwiseTransferHelper_Serpentine; + static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I10 = Number<10>{}; - static constexpr auto I12 = 
Number<12>{}; - static constexpr auto I13 = Number<13>{}; - static constexpr auto I14 = Number<14>{}; - static constexpr auto I16 = Number<16>{}; - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); + static constexpr index_t PackedSize = is_same_v, pk_i4_t> ? 2 : 1; static constexpr auto SrcScalarPerVector = Number{}; static constexpr auto DstScalarPerVector = Number{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( - const SrcDesc& src_desc, - const Index& src_slice_origin, - const SrcElementwiseOperation& src_element_op, - const DstDesc& dst_desc, - const Index& dst_slice_origin, - const DstElementwiseOperation& dst_element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), - src_element_op_(src_element_op), - dst_element_op_(dst_element_op) + // ===================================================================== + // Private implementation methods (must be declared before public methods + // that call them) + // ===================================================================== + __device__ static constexpr auto GetSrcCoordinateResetStep() { - if constexpr((packed_size_v) > 1) - { - static_assert(is_same_v, remove_cvref_t>, - "SrcData != DstData"); - - static_assert( - SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, - "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); - - static_assert(SrcVectorDim == DstVectorDim, - "Packed data type does not support transpose"); - } + return Helper::ComputeCoordinateResetStep(); } - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + __device__ static constexpr auto GetDstCoordinateResetStep() { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + return Helper::ComputeCoordinateResetStep(); } - __device__ void 
SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + return Helper:: + ComputeThreadScratchDescriptor(); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - Number thread_scratch_id = Number{}) + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, - "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // make forward and backward steps - const auto src_forward_steps = ComputeForwardSteps(src_desc, src_scalar_per_access); - const auto src_backward_steps = ComputeBackwardSteps(src_desc, src_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); - - 
constexpr auto src_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - // maintain a container record is_src_valid, waiting for RunWrite use. - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - src_oob_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, is_src_valid); - - using dst_vector_type = vector_type_maker_t; - using dst_vector_t = typename dst_vector_type::type; - dst_vector_type op_r_v; - - constexpr auto get_elem_op_vec_len = []() { - if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack8_invocable) - return math::min(8, SrcScalarPerVector); - } - else if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack4_invocable) - return math::min(4, SrcScalarPerVector); - } - else if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack2_invocable) - return math::min(2, SrcScalarPerVector); - } - else - { - return 1; - } - }; - - constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - - using src_elem_op_vec_t = typename vector_type::type; - using dst_elem_op_vec_t = typename vector_type::type; - - using VectorSizeLookupTable = Tuple, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence>; - using VectorOffsetsLookupTable = Tuple, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence>; - - static_for<0, tuple_element_t::Size(), 1>{}( - [&](auto v_idx) { - constexpr auto VectorLoadSize = - tuple_element_t::At(v_idx); - constexpr auto LoadOffset = - tuple_element_t::At(v_idx); - - using src_vector_container = vector_type_maker_t; - using src_vector_container_t = typename 
src_vector_container::type; - - src_vector_container src_vector = - src_vector_container{src_buf.template Get( - src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; - - static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { - // apply the src elementwise op and convert to DstData under the hood if - // needed - src_element_op_( - op_r_v.template AsType()(idx + LoadOffset), - src_vector.template AsType()[idx]); - }); - }); - - // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, - op_r_v.template AsType()[I0]); - - constexpr auto move_on_dim = - ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } + return make_naive_tensor_descriptor_packed(src_access_lengths); } - template - __device__ constexpr auto - GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) + __device__ static constexpr auto GetDstThreadScratchDescriptor() { - using vector_t = typename vector_type_maker::type::type; - return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); + return Helper:: + ComputeThreadScratchDescriptor(); } template @@ -327,14 +127,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto 
forward_sweep = - ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -439,6 +239,194 @@ struct ThreadwiseTensorSliceTransfer_v3r1 #endif } + // ===================================================================== + // Public interface + // ===================================================================== + public: + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + if constexpr((packed_size_v) > 1) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void 
SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); + + // calculate src data index + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // maintain a container record 
is_src_valid, waiting for RunWrite use. + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, is_src_valid); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + else + { + return 1; + } + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, + tuple_element_t::Size(), + 1>{}([&](auto v_idx) { + constexpr auto VectorLoadSize = + tuple_element_t::At(v_idx); + constexpr auto LoadOffset = + tuple_element_t::At( + v_idx); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + src_vector_container src_vector = + src_vector_container{src_buf.template Get( + src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; + + static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if + // needed + src_element_op_(op_r_v.template AsType()(idx + LoadOffset), + src_vector.template AsType()[idx]); + }); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, 
op_r_v.template AsType()[Helper::I0]); + + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + template __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, @@ -470,21 +458,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); // make forward and backward steps - const auto dst_forward_steps = ComputeForwardSteps(dst_desc, dst_scalar_per_access); - const auto dst_backward_steps = ComputeBackwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = - ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); // calculate dst data index - constexpr auto dst_data_idx = ComputeDataIndex(ordered_dst_access_idx, - ordered_dst_access_lengths, - forward_sweep, - dst_dim_access_order, - dst_scalar_per_access); + constexpr auto dst_data_idx = 
Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -510,10 +499,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_buf.template Set( dst_coord_.GetOffset() / PackedSize, is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); constexpr auto move_on_dim = - ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -543,21 +532,19 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } } - __device__ static constexpr auto GetSrcCoordinateResetStep() + template + __device__ constexpr auto + GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) { - return ComputeCoordinateResetStep(); - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - return ComputeCoordinateResetStep(); + using vector_t = typename vector_type_maker::type::type; + return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - MoveSliceWindow( + Helper::MoveSliceWindow( src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } @@ -565,252 +552,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - MoveSliceWindow( + Helper::MoveSliceWindow( dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } - __device__ static constexpr auto GetSrcThreadScratchDescriptor() - { - return ComputeThreadScratchDescriptor(); - } - - __device__ static 
constexpr auto GetSrcOOBThreadScratchDescriptor() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - return make_naive_tensor_descriptor_packed(src_access_lengths); - } - - __device__ static constexpr auto GetDstThreadScratchDescriptor() - { - return ComputeThreadScratchDescriptor(); - } - - protected: - // Helper function to compute forward sweep pattern - // I.e. if we should move forward or backward in each of tensor's dimensions - template - __device__ static constexpr auto - ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& ordered_access_lengths) - { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<1, i, 1>{}( - [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - } - - // Compute which dimensions should have their coordinates updated during iteration - // A dimension moves when it hasn't reached its end and all higher priority dimensions - // have completed their ranges - template - __device__ static constexpr auto - ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& ordered_access_lengths) - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - - // Compute data index from ordered access index, converting back to natural order - template - __device__ static constexpr auto - ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& 
ordered_access_lengths, - const ForwardSweep& forward_sweep, - const DimAccessOrder& dim_access_order, - const ScalarPerAccess& scalar_per_access) - { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; - } - - // Compute forward coordinate steps for each dimension - template - __device__ static constexpr auto ComputeForwardSteps(const Desc& desc, - const ScalarPerAccess& scalar_per_access) - { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(desc, forward_step_idx); - }, - Number{}); - } - - // Compute backward coordinate steps for each dimension - template - __device__ static constexpr auto ComputeBackwardSteps(const Desc& desc, - const ScalarPerAccess& scalar_per_access) - { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? 
-scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(desc, backward_step_idx); - }, - Number{}); - } - - // Generic helper to compute thread scratch descriptor - template - __device__ static constexpr auto ComputeThreadScratchDescriptor() - { - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - constexpr auto access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == VectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(access_lengths_and_vector_length[i], - access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == VectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - // Generic helper to move slice window - template - __device__ static void MoveSliceWindow(const Desc& desc, - Coord& coord, - const Index& slice_origin_step_idx, - GetCoordinateResetStepFunc get_reset_step) - { - // if coord was not reset by RunRead/RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = ResetCoordinateAfterRun - ? slice_origin_step_idx - : slice_origin_step_idx + get_reset_step(); - - // is it OK to construct a new step every time? 
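The `ComputeForwardSteps`/`ComputeBackwardSteps` helpers above generate, per dimension, a one-hot step index: `±scalar_per_access[i]` on dimension `i` and zero elsewhere, which is then handed to `make_tensor_coordinate_step`. A minimal plain-array sketch of that pattern, outside the CK descriptor machinery (function and parameter names here are illustrative, not CK API):

```cpp
#include <array>
#include <cstddef>

// Hypothetical standalone sketch: for each dimension i, build the multi-index
// that ComputeForwardSteps (sign = +1) or ComputeBackwardSteps (sign = -1)
// would pass to make_tensor_coordinate_step: +/-scalar_per_access[i] on
// dimension i, zero on every other dimension.
template <std::size_t NDim>
constexpr std::array<std::array<int, NDim>, NDim>
make_step_indices(const std::array<int, NDim>& scalar_per_access, int sign)
{
    std::array<std::array<int, NDim>, NDim> steps{};
    for(std::size_t i = 0; i < NDim; ++i)
        for(std::size_t j = 0; j < NDim; ++j)
            steps[i][j] = (i == j) ? sign * scalar_per_access[i] : 0;
    return steps;
}
```

Because this depends only on `NDim` and the scalar-per-access values, a single instantiation serves every transfer variant and data type, which is the core of the compile-time saving.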
- const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx); - - move_tensor_coordinate(desc, coord, adjusted_step); - } - - // Generic helper to compute coordinate reset step - template - __device__ static constexpr auto ComputeCoordinateResetStep() - { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto ordered_access_lengths_minus_1 = generate_tuple( - [&](auto i) { return Number{}; }, Number{}); - constexpr auto forward_sweep = - ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths); - - // calculate data index after last iteration, if it has not being reset - constexpr auto data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - scalar_per_access; - }(); - - // - constexpr auto reset_data_step = [&]() { - Index reset_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); - - return reset_data_step_; - }(); - - return reset_data_step; - } - + // ===================================================================== + // Private data members + // ===================================================================== private: static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_oob_thread_scratch_desc_ = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp index 2ddb34671a..7545c8c416 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -7,43 +7,12 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { -namespace detail { -// TODO: How to fix this? 
It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access_for_src_and_dst_idle -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; - -} // namespace detail - // Assume: // 1. src_desc and dst_desc are not known at compile-time // 2. SrcBuffer and DstBuffer are DynamicBuffer @@ -84,12 +53,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; + using Helper = ThreadwiseTransferHelper_Serpentine; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -139,7 +108,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! SrcBuffer and SrcData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -150,66 +118,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? 
src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -227,22 +150,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + src_data_idx_seq, + src_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -284,7 +196,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! 
ScaleBuffer and ScaleData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto scale_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -295,66 +206,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_scale_access_lengths = container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); - // make forward steps - const auto scale_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(scale_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto scale_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(scale_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto scale_forward_steps = + Helper::ComputeForwardSteps(scale_desc, scale_scalar_per_access); + const auto scale_backward_steps = + Helper::ComputeBackwardSteps(scale_desc, scale_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_scale_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_scale_access_idx, ordered_scale_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_scale_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); 
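The removed `forward_sweep` lambda above (now centralized as `Helper::ComputeForwardSweep`) picks the traversal direction per dimension from the parity of the flattened index of the slower dimensions. A standalone sketch of that rule, with plain runtime arrays standing in for CK's compile-time sequences (names are illustrative):

```cpp
#include <array>
#include <cstddef>

// Illustrative sketch of the serpentine sweep rule: dimension 0 always moves
// forward; dimension i moves forward iff the flattened index formed from the
// slower dimensions 0..i-1 is even, so each inner pass reverses direction
// whenever an outer counter advances. This keeps consecutive accesses on
// adjacent coordinates instead of jumping back to the start of each row.
template <std::size_t NDim>
constexpr std::array<bool, NDim>
compute_forward_sweep(const std::array<int, NDim>& idx,
                      const std::array<int, NDim>& lengths)
{
    std::array<bool, NDim> sweep{};
    sweep[0] = true;
    for(std::size_t i = 1; i < NDim; ++i)
    {
        int tmp = idx[0];
        for(std::size_t j = 1; j < i; ++j)
            tmp = tmp * lengths[j] + idx[j];
        sweep[i] = (tmp % 2 == 0);
    }
    return sweep;
}
```

Note the result depends only on the index and length types, which is why hoisting it out of the 14-parameter outer class collapses so many redundant instantiations.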
- - // calculate scale data index - constexpr auto scale_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i] - : ordered_scale_access_lengths[i] - 1 - - ordered_scale_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * - scale_scalar_per_access; - }(); + constexpr auto scale_data_idx = Helper::ComputeDataIndex(ordered_scale_access_idx, + ordered_scale_access_lengths, + forward_sweep, + scale_dim_access_order, + scale_scalar_per_access); constexpr auto scale_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -372,23 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // copy data from scale_vector_container into scale_thread_scratch_ scale_thread_scratch_.template SetAsType( - scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + scale_data_idx_seq, + scale_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = - ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_scale_access_idx, ordered_scale_access_lengths); // move scale coord static_for<0, nDim, 1>{}([&](auto i) { @@ -409,17 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant } }); }); - - // don't need to move scale coordinate back to slice origin - /* - if constexpr(SrcResetCoordinateAfterRun) - { - const auto scale_reset_step = - make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); - - move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); - } - */ } template @@ -460,10 +304,10 @@ struct 
ThreadwiseTensorSliceTransfer_v3r1_dequant detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + detail::lambda_scalar_per_access_for_src_and_dst{}, Number{}); constexpr auto access_lengths = SliceLengths{} / scalar_per_access; @@ -504,10 +348,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // Do fast numeric convert constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + detail::lambda_scalar_per_access_for_src_and_dst{}, Number{}); constexpr auto access_lengths = SliceLengths{} / scalar_per_access; @@ -528,15 +372,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant src_converted_thread_scratch_.template SetAsType( access_idx, - src_converted_vector_container.template AsType()[I0]); + src_converted_vector_container + .template AsType()[Helper::I0]); }); // Element-scale operation, expect packed multiplication static_ford{}([&](auto idx) { DstData dst_v; - constexpr auto scale_idx = Sequence{}; - // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), - // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + constexpr auto scale_idx = Sequence{}; src_element_op_(dst_v, src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); dst_thread_scratch_(idx) = dst_v; @@ -562,7 +405,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! 
SrcBuffer or DstBuffer data type is wrong"); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -573,66 +415,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { 
- Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); + constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -660,22 +457,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant dst_buf.template Set( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -707,293 +492,52 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move 
backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; + return Helper::ComputeCoordinateResetStep(); } __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = 
tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + Helper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? 
dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + Helper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetScaleThreadScratchDescriptor() { - - constexpr auto scale_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; - - constexpr 
auto scale_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(scale_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(scale_access_lengths_and_vector_length[i], - scale_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - 
constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } private: @@ -1002,11 +546,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant decltype(GetScaleThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - /* - template - struct ScaleThreadScratchDesc{}; - */ - // Registers, contain raw data loaded from global buffer using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + using Helper = ThreadwiseTransferHelper_Serpentine; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I10 = Number<10>{}; - static constexpr auto I12 = Number<12>{}; - static constexpr auto I13 = Number<13>{}; - static constexpr auto I14 = Number<14>{}; - static constexpr auto I16 = Number<16>{}; - static constexpr index_t PackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -142,7 +126,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather "wrong! 
SrcBuffer and SrcData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -156,66 +139,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto 
src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -274,24 +214,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, - op_r_v.template AsType()[I0]); + .template SetAsType( + src_data_idx_seq, op_r_v.template AsType()[Helper::I0]); + // Gather-specific: skip gather dimension during coordinate movement auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; + auto move_on_dim_ = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - move_on_dim_(i) &= i.value != ordered_gather_dim; - }); + static_for<0, nDim, 1>{}( + [&](auto i) { move_on_dim_(i) &= i.value != ordered_gather_dim; }); return move_on_dim_; }(); + // move src coord static_for<0, nDim, 1>{}([&](auto i) { if(move_on_dim[i]) @@ -351,38 +287,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto 
forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -501,7 +413,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather "wrong! SrcBuffer or DstBuffer data type is wrong"); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -512,66 +423,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? 
dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); + constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -599,22 +465,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather dst_buf.template Set( dst_coord_.GetOffset() / PackedSize, is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -644,10 +498,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather } } + // Gather-specific: src coordinate reset zeroes the gather dimension __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -658,29 +511,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray 
forward_sweep_; + constexpr auto ordered_access_lengths_minus_1 = generate_tuple( + [&](auto i) { return Number{}; }, Number{}); + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() constexpr auto src_data_idx = [&]() { - Index ordered_idx; + MultiIndex ordered_idx; static_for<0, nDim, 1>{}([&](auto i) { ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; @@ -690,9 +527,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather src_scalar_per_access; }(); - // + // Gather-specific: don't reset the gather dimension constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; + MultiIndex reset_src_data_step_; static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 
0 : -src_data_idx[i]; @@ -705,137 +542,32 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + Helper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? 
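The `MoveSrcSliceWindow`/`MoveDstSliceWindow` bodies deleted in these hunks all encode the same rule that `Helper::MoveSliceWindow` now centralizes: if the coordinate was already reset after `RunRead()`/`RunWrite()`, use the caller's step as-is; otherwise fold the pending reset step into it first. A minimal host-side sketch of that rule (illustrative names; the real helper operates on tensor coordinates and compile-time steps):

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Runtime model of the adjusted-step computation shared by the slice-window
// moves: when reset_after_run is false, the un-applied coordinate reset step
// is added to the requested step before the move.
template <std::size_t N>
std::array<int, N> adjusted_move_step(bool reset_after_run,
                                      const std::array<int, N>& step,
                                      const std::array<int, N>& reset_step)
{
    std::array<int, N> adjusted = step;
    if(!reset_after_run)
    {
        for(std::size_t i = 0; i < N; ++i)
            adjusted[i] += reset_step[i]; // fold in the pending reset
    }
    return adjusted;
}
```

Because the rule depends only on the step and reset-step values, the shared helper needs none of the outer class's data-type or descriptor parameters, which is exactly what makes the instantiation reusable across variants.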
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + Helper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() @@ -850,50 +582,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / 
dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index 3c7291cca3..24fbd66be6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -7,10 +7,11 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" #include "ck/utility/is_detected.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Assume: @@ -48,6 +49,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 
static constexpr index_t nSrc = SrcDescs::Size(); static constexpr index_t nDst = DstDescs::Size(); + using Helper = ThreadwiseTransferHelper_Serpentine; + // return a tuple of coordiantes for a tuple of tensor template {})); using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( const SrcDescs& src_descs, const StaticallyIndexedArray& src_slice_origins, @@ -101,7 +102,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2 Number thread_scratch_id = Number{}) { // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access_tuple = generate_tuple( [&](auto src_i) { return generate_sequence( @@ -129,40 +129,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - // make forward steps + // make forward and backward steps const auto src_forward_steps_tuple = generate_tuple( [&](auto src_i) { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = - (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; - }); - - return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); - }, - Number{}); + return Helper::ComputeForwardSteps(src_descs.At(src_i), + src_scalar_per_access_tuple.At(src_i)); }, Number{}); - // make backward steps const auto src_backward_steps_tuple = generate_tuple( [&](auto src_i) { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? 
-src_scalar_per_access_tuple.At(src_i)[i] - : 0; - }); - - return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); - }, - Number{}); + return Helper::ComputeBackwardSteps(src_descs.At(src_i), + src_scalar_per_access_tuple.At(src_i)); }, Number{}); @@ -171,39 +149,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_ford>{}( [&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths_tuple[j] + - ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = Helper::ComputeForwardSweep( + ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i)); // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? 
ordered_src_access_idx[i] - : ordered_src_access_lengths_tuple.At(src_i)[i] - - 1 - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access_tuple.At(src_i); - }(); + constexpr auto src_data_idx = + Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths_tuple.At(src_i), + forward_sweep, + src_dim_access_order, + src_scalar_per_access_tuple.At(src_i)); constexpr auto src_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -227,24 +182,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2 .At(src_i) .template SetAsType( src_data_idx_seq, - src_vector_container.template AsType()[I0]); + src_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < - ordered_src_access_lengths_tuple.At(src_i)[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == - ordered_src_access_lengths_tuple.At(src_i)[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = Helper::ComputeMoveOnDim( + ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i)); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -287,18 +228,30 @@ struct ThreadwiseTensorSliceTransfer_v3r2 { // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE // (it requires to add Elementwise support in transpose_vectors) - static_ford{}([&](auto idx) { - const auto src_data_refs = generate_tie( - [&](auto src_i) -> const auto& { - return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; - }, - Number{}); + if constexpr(nSrc == 1 && nDst == 1) + { + // Fast path: direct element transfer, no generate_tie/unpack2 overhead + static_ford{}([&](auto idx) { + element_op_(dst_thread_scratch_tuple_.At(Number<0>{})(idx), + 
src_thread_scratch_tuple_[thread_scratch_id].At(Number<0>{})[idx]); + }); + } + else + { + // General path: use generate_tie + unpack2 for multi-src/dst + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + [&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); - auto dst_data_refs = generate_tie( - [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, - Number{}); - unpack2(element_op_, dst_data_refs, src_data_refs); - }); + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } } template @@ -311,7 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2 TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access_tuple = generate_tuple( [&](auto dst_i) { return generate_sequence( @@ -334,40 +286,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - // make forward steps + // make forward and backward steps const auto dst_forward_steps_tuple = generate_tuple( [&](auto dst_i) { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = - (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0; - }); - - return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); - }, - Number{}); + return Helper::ComputeForwardSteps(dst_descs.At(dst_i), + dst_scalar_per_access_tuple.At(dst_i)); }, Number{}); - // make backward steps const auto dst_backward_steps_tuple = generate_tuple( [&](auto dst_i) { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? 
-dst_scalar_per_access_tuple.At(dst_i)[i] - : 0; - }); - - return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); - }, - Number{}); + return Helper::ComputeBackwardSteps(dst_descs.At(dst_i), + dst_scalar_per_access_tuple.At(dst_i)); }, Number{}); @@ -376,39 +306,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_ford>{}( [&](auto ordered_dst_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + - ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = Helper::ComputeForwardSweep( + ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i)); // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? 
ordered_dst_access_idx[i] - : ordered_dst_access_lengths_tuple.At(dst_i)[i] - - 1 - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access_tuple.At(dst_i); - }(); + constexpr auto dst_data_idx = + Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths_tuple.At(dst_i), + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access_tuple.At(dst_i)); constexpr auto dst_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -434,24 +341,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2 dst_bufs.At(dst_i).template Update( dst_coords_.At(dst_i).GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < - ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == - ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = Helper::ComputeMoveOnDim( + ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i)); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -491,121 +384,19 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // 
judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; + return Helper::ComputeCoordinateResetStep(); } template __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - 
static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access.At(dst_i); - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason @@ -649,103 +440,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = - container_push_back(sequence_to_tuple_of_number(src_access_lengths), - Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, 
- Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper::ComputeThreadScratchDescriptor(); } template __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = - container_push_back(sequence_to_tuple_of_number(dst_access_lengths), - Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper::ComputeThreadScratchDescriptor(); } __device__ static constexpr auto MakeSrcThreadScratchTuple() diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 74a964ddd8..45b638c842 
100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp
@@ -42,8 +42,6 @@ struct ThreadwiseTensorSliceTransfer_v4r1
 
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
 
-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
         : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
index bce2d453dc..5d14d66eb3 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
@@ -44,9 +44,6 @@ struct ThreadwiseTensorSliceTransfer_v5r1
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
 
-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                             const Index& src_slice_origin,
                                                             const DstDesc& dst_desc,
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
index 2e255e2500..fc0ec9128d 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
@@ -8,6 +8,8 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/tensor_space_filling_curve.hpp"
 
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
+
 namespace ck {
 
 // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -40,11 +42,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1
     using Index = MultiIndex;
 
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
 
-    static constexpr auto I0 = Number<0>{};
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                             const Index& src_slice_origin,
                                                             const DstDesc& dst_desc,
@@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
         dst_buf.template Update(
             dst_coord_.GetOffset(),
             is_dst_valid,
-            dst_vector_container.template AsType()[I0]);
+            dst_vector_container.template AsType()[SFCHelper::I0]);
 
         // move coordinate
         if constexpr(idx_1d.value != num_access - 1)
@@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1
         constexpr auto scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access{}, Number{});
 
-        using SpaceFillingCurve = SpaceFillingCurve>;
-
-        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
-        if constexpr(num_access == 0)
-        {
-            return typename SpaceFillingCurve::Index{};
-        }
-        else
-        {
-            constexpr auto reset_step =
-                SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{});
-
-            return reset_step;
-        }
+        return SFCHelper::ComputeSFCCoordinateResetStep();
     }
 
     // src_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                        const Index& src_slice_origin_step_idx)
     {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx = SrcResetCoordinateAfterRun
-                                           ? src_slice_origin_step_idx
-                                           : src_slice_origin_step_idx + GetCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
-
-        move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
+        SFCHelper::MoveSliceWindow(
+            src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep);
     }
 
     // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                        const Index& dst_slice_origin_step_idx)
     {
-        // if dst coord was not reset by Run(), then need to adjust the step here
-        const auto adjusted_step_idx = DstResetCoordinateAfterRun
-                                           ? dst_slice_origin_step_idx
-                                           : dst_slice_origin_step_idx + GetCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
-        const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
-
-        move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
+        SFCHelper::MoveSliceWindow(
+            dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep);
     }
 
     private:
diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
index 43d4148dab..711f693f6f 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp
@@ -8,6 +8,8 @@
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_description/tensor_space_filling_curve.hpp"
 
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
+
 namespace ck {
 
 // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -39,11 +41,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
     using Index = MultiIndex;
 
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
 
-    static constexpr auto I0 = Number<0>{};
-
     __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
         const SrcDesc& src_desc,
         const Index& src_slice_origin,
@@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
         dst_buf.template Update(
             dst_coord_.GetOffset(),
             is_dst_valid,
-            dst_vector_container.template AsType()[I0]);
+            dst_vector_container.template AsType()[SFCHelper::I0]);
 
         // move coordinate
         if constexpr(idx_1d.value != num_access - 1)
@@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
         constexpr auto scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access{}, Number{});
 
-        using SpaceFillingCurve = SpaceFillingCurve>;
-
-        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
-        if constexpr(num_access == 0)
-        {
-            return typename SpaceFillingCurve::Index{};
-        }
-        else
-        {
-            constexpr auto reset_step =
-                SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{});
-
-            return reset_step;
-        }
+        return SFCHelper::ComputeSFCCoordinateResetStep();
    }
 
     // src_slice_origin_step_idx need to be known at compile-time, for performance reason
     __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                        const Index& src_slice_origin_step_idx)
     {
-        // if src coord was not reset by RunRead(), then need to adjust the step here
-        const auto adjusted_step_idx = SrcResetCoordinateAfterRun
-                                           ? src_slice_origin_step_idx
-                                           : src_slice_origin_step_idx + GetCoordinateResetStep();
-
-        // is it OK to construct a new step every time?
- const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index f036bc4312..f7e5aa3adf 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -43,12 +45,12 @@ struct ThreadwiseTensorSliceTransfer_v6r2 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); 
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, const Index& src0_slice_origin, const Src1Desc& src1_desc, @@ -141,7 +143,7 @@ struct ThreadwiseTensorSliceTransfer_v6r2 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -187,67 +189,30 @@ struct ThreadwiseTensorSliceTransfer_v6r2 constexpr auto scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& src0_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src0ResetCoordinateAfterRun - ? src0_slice_origin_step_idx - : src0_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? 
- const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); - - move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& src1_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src1ResetCoordinateAfterRun - ? src1_slice_origin_step_idx - : src1_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); - - move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? 
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 7d53c1ac0d..79a6b5d3aa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -46,13 +48,13 @@ struct ThreadwiseTensorSliceTransfer_v6r3 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, const Index& src0_slice_origin, const Src1Desc& src1_desc, @@ -165,7 +167,7 @@ struct ThreadwiseTensorSliceTransfer_v6r3 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -221,82 +223,37 @@ struct ThreadwiseTensorSliceTransfer_v6r3 constexpr auto scalar_per_access = 
generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& src0_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src0ResetCoordinateAfterRun - ? src0_slice_origin_step_idx - : src0_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); - - move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& src1_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src1ResetCoordinateAfterRun - ? src1_slice_origin_step_idx - : src1_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? 
- const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); - - move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& src2_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src2ResetCoordinateAfterRun - ? src2_slice_origin_step_idx - : src2_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); - - move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src2_desc, src2_coord_, src2_slice_origin_step_idx, GetCoordinateResetStep); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? 
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp index 6326f6cbda..64f9ac2243 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -55,6 +55,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + // return a tuple of coordiantes for a tuple of tensor template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -473,98 +465,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - 
make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at 
compile-time, for performance reason diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index b4ee81697e..c4fad23f70 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -56,6 +56,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + // return a tuple of coordiantes for a tuple of tensor template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -615,100 +607,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = 
generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 732922c157..45bd6f3f8e 
100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -63,6 +63,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static constexpr index_t nDst = DstDescs::Size(); using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); // return a tuple of coordiantes for a tuple of tensor @@ -134,17 +135,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -506,100 +497,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - 
if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b0b5f1c82f..017391549a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -255,6 +255,7 @@ add_compile_options(-Wno-c++20-extensions) add_subdirectory(ck_tile) 
add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) +add_subdirectory(threadwise_transfer_helper) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) diff --git a/test/threadwise_transfer_helper/CMakeLists.txt b/test/threadwise_transfer_helper/CMakeLists.txt new file mode 100644 index 0000000000..d157f19500 --- /dev/null +++ b/test/threadwise_transfer_helper/CMakeLists.txt @@ -0,0 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +add_gtest_executable(test_threadwise_transfer_helper test_threadwise_transfer_helper.cpp) diff --git a/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp b/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp new file mode 100644 index 0000000000..0033fb0db8 --- /dev/null +++ b/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp @@ -0,0 +1,748 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include <gtest/gtest.h>
+#include <type_traits>
+
+#include "ck/ck.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
+
+using namespace ck;
+
+// =============================================================================
+// ThreadwiseTransferHelper_Base tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperBase, CompileTimeConstants)
+{
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I0.value, 0);
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I1.value, 1);
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I2.value, 2);
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I4.value, 4);
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I8.value, 8);
+    EXPECT_EQ(ThreadwiseTransferHelper_Base::I16.value, 16);
+}
+
+TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySerpentine)
+{
+    // Serpentine inherits all constants from Base via public inheritance.
+    EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I0.value, 0);
+    EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I16.value, 16);
+}
+
+TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySFC)
+{
+    // SFC inherits all constants from Base via public inheritance.
+    EXPECT_EQ(ThreadwiseTransferHelper_SFC::I0.value, 0);
+    EXPECT_EQ(ThreadwiseTransferHelper_SFC::I16.value, 16);
+}
+
+// =============================================================================
+// ThreadwiseTransferHelper_Base::MoveSliceWindow tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetAlreadyDone)
+{
+    /*
+     * Scenario: v3r1's MoveSrcSliceWindow after RunRead has already reset
+     * the coordinate back to the slice origin (SrcResetCoordinateAfterRun=true).
+     *
+     * 2D packed tensor (4 rows x 8 columns), modelling a tile transfer:
+     *
+     *        col: 0   1   2   3   4   5   6   7
+     *   row 0:   [*]  .   .   .   .   .   .   .   <-- start at (0,0), offset=0
+     *   row 1:    .   .
.   .   .   .   .   .
+     *   row 2:    .   .   .   .   .   .   .   .
+     *   row 3:    .   .   .   .   .   .   .   .
+     *
+     * Step = (1, 0): move one row down.
+     * Reset step = (-3, 0): would move 3 rows up (irrelevant here).
+     *
+     * Since ResetCoordinateAfterRun=true, the reset step is NOT fused
+     * into the movement. The coordinate simply moves by the step alone.
+     *
+     * Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8
+     */
+    using Helper = ThreadwiseTransferHelper_Base;
+
+    constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
+
+    auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));
+
+    EXPECT_EQ(coord.GetOffset(), 0);
+
+    const auto step_idx = make_multi_index(1, 0);
+
+    auto get_reset_step = []() { return make_multi_index(-3, 0); };
+
+    Helper::MoveSliceWindow<true>(
+        desc, coord, step_idx, get_reset_step);
+
+    // Coordinate moved by step only: (0,0) -> (1,0)
+    // Offset in row-major packed layout: 1*8 + 0 = 8
+    EXPECT_EQ(coord.GetOffset(), 8);
+}
+
+TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetFused)
+{
+    /*
+     * Scenario: v3r1's MoveSrcSliceWindow when RunRead did NOT reset
+     * the coordinate (SrcResetCoordinateAfterRun=false). This is the
+     * optimization path where MoveSliceWindow fuses the reset step
+     * with the movement step to save a separate coordinate adjustment.
+     *
+     * Same 2D packed tensor (4 rows x 8 columns):
+     *
+     *        col: 0   1   2   3   4   5   6   7
+     *   row 0:   [*]  .   .   .   .   .   .   .   <-- start at (0,0), offset=0
+     *   row 1:    .   .   .   .   .   .   .   .
+     *   row 2:    .   .   .   .   .   .   .   .
+     *   row 3:    .   .   .   .   .   .   .   .
+     *
+     * Step = (2, 0): move two rows down.
+     * Reset step = (-1, 0): move one row up (e.g., undo traversal overshoot).
+     *
+     * Since ResetCoordinateAfterRun=false, MoveSliceWindow adds the
+     * reset step to the movement step before applying:
+     *     adjusted_step = step + reset = (2,0) + (-1,0) = (1,0)
+     *
+     * Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8
+     */
+    using Helper = ThreadwiseTransferHelper_Base;
+
+    constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
+
+    auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));
+
+    EXPECT_EQ(coord.GetOffset(), 0);
+
+    const auto step_idx = make_multi_index(2, 0);
+
+    auto get_reset_step = []() { return make_multi_index(-1, 0); };
+
+    Helper::MoveSliceWindow<false>(
+        desc, coord, step_idx, get_reset_step);
+
+    // adjusted_step = (2,0) + (-1,0) = (1,0)
+    // Offset: 1*8 + 0 = 8
+    EXPECT_EQ(coord.GetOffset(), 8);
+}
+
+TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_3D_ResetFused)
+{
+    /*
+     * Scenario: 3D packed tensor (2 x 4 x 8), modelling a typical GEMM
+     * intermediate buffer with SliceLengths = (batch, row, col).
+     *
+     * Layout (batch=0 shown, row-major packed):
+     *
+     *   batch 0:
+     *        col: 0   1   2   3   4   5   6   7
+     *   row 0:    .   .   .   .   .   .   .   .
+     *   row 1:    .   .   .   .   .   .   .   .
+     *   row 2:    .   .   .   .   .   .   .   .
+     *   row 3:    .   .   .   .   .   .   .   .
+     *
+     *   batch 1: (same structure, offset += 4*8 = 32)
+     *
+     * Start at (0, 0, 0), offset=0.
+     *
+     * Step = (0, 2, 0): move 2 rows down within the same batch.
+     * Reset step = (0, -1, 0): undo 1 row of traversal overshoot.
+     *
+     * ResetCoordinateAfterRun=false, so steps are fused:
+     *     adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0)
+     *
+     * Expected: (0,0,0) + (0,1,0) = (0,1,0)
+     * Offset in packed layout: 0*(4*8) + 1*8 + 0 = 8
+     */
+    using Helper = ThreadwiseTransferHelper_Base;
+
+    constexpr auto desc =
+        make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}, Number<8>{}));
+
+    auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0, 0));
+
+    EXPECT_EQ(coord.GetOffset(), 0);
+
+    const auto step_idx = make_multi_index(0, 2, 0);
+
+    auto get_reset_step = []() { return make_multi_index(0, -1, 0); };
+
+    Helper::MoveSliceWindow<false>(
+        desc, coord, step_idx, get_reset_step);
+
+    // adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0)
+    // Offset: 0*32 + 1*8 + 0 = 8
+    EXPECT_EQ(coord.GetOffset(), 8);
+}
+
+// =============================================================================
+// ThreadwiseTransferHelper_Serpentine::ComputeForwardSweep tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_EvenRow)
+{
+    /*
+     * 2D serpentine traversal on a 4x4 grid:
+     *
+     *        dim1 ->
+     *        0  1  2  3
+     *   +-->-->-->--+    row 0: forward  (dim0=0 is even)
+     *   +--<--<--<--+    row 1: backward (dim0=1 is odd)
+     *   +-->-->-->--+    row 2: forward  (dim0=2 is even)
+     *   +--<--<--<--+    row 3: backward (dim0=3 is odd)
+     *   dim0
+     *
+     * At position (0, *): dim0 is even -> dim1 sweeps FORWARD
+     */
+    using Helper = ThreadwiseTransferHelper_Serpentine;
+
+    constexpr auto idx     = make_tuple(Number<0>{}, Number<0>{});
+    constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{});
+    constexpr auto sweep   = Helper::ComputeForwardSweep(idx, lengths);
+
+    EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward (outermost)
+    EXPECT_TRUE(sweep[Number<1>{}]); // dim 1: forward because dim0 position (0) is even
+}
+
+TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_OddRow)
+{
+    /*
+     * Same 4x4 grid, but at row 1:
+
* + * +-->-->-->--+ row 0 + * +--<--<--<--+ row 1: dim0=1 is odd -> dim1 sweeps BACKWARD + * + * At position (1, *): dim0 is odd -> dim1 sweeps BACKWARD + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<1>{}, Number<0>{}); + constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + + EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward + EXPECT_FALSE(sweep[Number<1>{}]); // dim 1: backward (dim0 position 1 is odd) +} + +TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_1D) +{ + /* + * 1D traversal: always forward regardless of position. + * + * 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<3>{}); + constexpr auto lengths = make_tuple(Number<8>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + + EXPECT_TRUE(sweep[Number<0>{}]); // 1D: only dimension, always forward +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeMoveOnDim tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerNotComplete) +{ + /* + * 2D grid with ordered_access_lengths = (3, 2): + * + * dim1: 0 1 + * dim0: + * 0 [*] . <-- at (0,0): dim1 hasn't reached end yet + * 1 . . + * 2 . . + * + * Rule: a dimension moves only when all faster-varying (higher-index) + * dimensions have completed their range. + * + * At (0, 0): + * dim0: dim1 is at 0, not at end (1). -> dim0 does NOT move. + * dim1: no higher dims to check, and 0 < 1. -> dim1 MOVES. 
+ */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<0>{}, Number<0>{}); + constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{}); + constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths); + + EXPECT_FALSE(move[Number<0>{}]); // dim 0: inner dim NOT at end + EXPECT_TRUE(move[Number<1>{}]); // dim 1: can advance +} + +TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerComplete) +{ + /* + * Same grid, at position (0, 1): + * + * dim1: 0 1 + * dim0: + * 0 . [*] <-- at (0,1): dim1 at its end (1 == 2-1) + * 1 . . + * 2 . . + * + * At (0, 1): + * dim0: dim1 is at end (1 == 1). dim0 < 2. -> dim0 MOVES. + * dim1: at end. -> dim1 does NOT move. + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<0>{}, Number<1>{}); + constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{}); + constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths); + + EXPECT_TRUE(move[Number<0>{}]); // dim 0: inner dim at end, can advance + EXPECT_FALSE(move[Number<1>{}]); // dim 1: at its limit, cannot advance +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeDataIndex tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeDataIndex_ForwardSweep) +{ + /* + * 2D grid (4x3), both dims sweeping forward, identity order, scale=1: + * + * ordered_access_idx = (2, 1) + * forward_sweep = (true, true) + * dim_access_order = (0, 1) <-- identity + * scalar_per_access = (1, 1) <-- no scaling + * + * Forward: data_idx = ordered_idx = (2, 1) + * Reorder: identity -> (2, 1) + * Scale: * (1,1) -> (2, 1) + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<2>{}, Number<1>{}); + constexpr auto lengths = make_tuple(Number<4>{}, Number<3>{}); + constexpr auto sweep = 
Helper::ComputeForwardSweep(idx, lengths);
+    constexpr auto order = Sequence<0, 1>{};
+    constexpr auto spa   = Sequence<1, 1>{};
+
+    constexpr auto data_idx = Helper::ComputeDataIndex(idx, lengths, sweep, order, spa);
+
+    EXPECT_EQ(data_idx[Number<0>{}], 2);
+    EXPECT_EQ(data_idx[Number<1>{}], 1);
+}
+
+// =============================================================================
+// ThreadwiseTransferHelper_Serpentine::ComputeCoordinateResetStep tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperSerpentine, ComputeCoordinateResetStep_2D)
+{
+    /*
+     * SliceLengths = (4, 2), VectorDim = 1, ScalarPerVector = 2
+     * DimAccessOrder = (0, 1)
+     *
+     * scalar_per_access = (1, 2)   [only dim 1 is vectorized with width 2]
+     * access_lengths    = (4, 1)   [4/1=4, 2/2=1]
+     *
+     * The traversal visits 4 positions along dim 0, each accessing 2 elements:
+     *
+     *   dim0=0: access [0,0..1]
+     *   dim0=1: access [1,0..1]   (backward sweep, but only 1 step on dim1)
+     *   dim0=2: access [2,0..1]
+     *   dim0=3: access [3,0..1]
+     *
+     * Final position: data_idx = (3, 0) * scalar_per_access = (3, 0)
+     * Reset step: -(3, 0) = (-3, 0)
+     */
+    using Helper = ThreadwiseTransferHelper_Serpentine;
+
+    constexpr auto reset =
+        Helper::ComputeCoordinateResetStep<Sequence<4, 2>, 1, 2, Sequence<0, 1>>();
+
+    EXPECT_EQ(reset[Number<0>{}], -3);
+    EXPECT_EQ(reset[Number<1>{}], 0);
+}
+
+// =============================================================================
+// VectorSizeLookupTable / VectorOffsetsLookupTable tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperSerpentine, VectorSizeLookupTable)
+{
+    /*
+     * Binary decomposition of vector widths into power-of-2 sub-loads:
+     *
+     *   Width 0:  (empty)    -- no loads
+     *   Width 1:  {1}        -- single 1-wide load
+     *   Width 7:  {4, 2, 1}  -- 4+2+1 = 7
+     *   Width 8:  {8}        -- single 8-wide load
+     *   Width 16: {16}       -- single 16-wide load
+     */
+    using Helper =
ThreadwiseTransferHelper_Serpentine; + + using VecSize0 = tuple_element_t<0, Helper::VectorSizeLookupTable>; + using VecSize1 = tuple_element_t<1, Helper::VectorSizeLookupTable>; + using VecSize7 = tuple_element_t<7, Helper::VectorSizeLookupTable>; + using VecSize8 = tuple_element_t<8, Helper::VectorSizeLookupTable>; + using VecSize16 = tuple_element_t<16, Helper::VectorSizeLookupTable>; + + EXPECT_EQ(VecSize0::Size(), 0); + + EXPECT_EQ(VecSize1::Size(), 1); + EXPECT_EQ(VecSize1::At(0), 1); + + EXPECT_EQ(VecSize7::Size(), 3); + EXPECT_EQ(VecSize7::At(0), 4); // first sub-load: 4 elements + EXPECT_EQ(VecSize7::At(1), 2); // second sub-load: 2 elements + EXPECT_EQ(VecSize7::At(2), 1); // third sub-load: 1 element + + EXPECT_EQ(VecSize8::Size(), 1); + EXPECT_EQ(VecSize8::At(0), 8); + + EXPECT_EQ(VecSize16::Size(), 1); + EXPECT_EQ(VecSize16::At(0), 16); +} + +TEST(ThreadwiseTransferHelperSerpentine, VectorOffsetsLookupTable) +{ + /* + * Starting element offsets for each sub-load in the decomposition: + * + * Width 7 = {4, 2, 1}: + * |<--- 4 --->|<- 2 ->|1| + * offset 0 offset 4 offset 6 + * + * So offsets = {0, 4, 6} + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + using VecOff7 = tuple_element_t<7, Helper::VectorOffsetsLookupTable>; + + EXPECT_EQ(VecOff7::Size(), 3); + EXPECT_EQ(VecOff7::At(0), 0); // first sub-load starts at offset 0 + EXPECT_EQ(VecOff7::At(1), 4); // second sub-load starts at offset 4 + EXPECT_EQ(VecOff7::At(2), 6); // third sub-load starts at offset 6 +} + +// ============================================================================= +// ThreadwiseTransferHelper_SFC tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_SingleAccess) +{ + /* + * SliceLengths = (1, 1), ScalarPerAccess = (1, 1) + * Only 1 access position total -> already at origin, reset = (0, 0) + * + * [*] <-- single element, no movement needed + */ + using SFCHelper = 
ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 1>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<1, 1>,
+                                                                    Sequence<0, 1>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], 0);
+    EXPECT_EQ(reset[Number<1>{}], 0);
+}
+
+TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_RowMajor)
+{
+    /*
+     * Typical v6r1 scenario: 2D slice transfer with vectorized column access.
+     *
+     * SliceLengths    = (4, 8)  -- 4 rows, 8 columns
+     * DimAccessOrder  = (0, 1)  -- row-major traversal (rows change slowest)
+     * ScalarPerAccess = (1, 4)  -- 4-wide vector loads along columns
+     *
+     * access_lengths = SliceLengths / ScalarPerAccess = (4, 2)
+     *
+     * The SFC traverses in serpentine order through 4*2 = 8 access positions:
+     *
+     *        col: 0..3   4..7
+     *   row 0:    [0]--> [1]    access 0 -> idx (0,0), access 1 -> idx (0,4)
+     *   row 1:    [3]<-- [2]    access 2 -> idx (1,4), access 3 -> idx (1,0)
+     *   row 2:    [4]--> [5]    access 4 -> idx (2,0), access 5 -> idx (2,4)
+     *   row 3:    [7]<-- [6]    access 6 -> idx (3,4), access 7 -> idx (3,0)
+     *
+     * Last access (#7) lands at index (3, 0).
+     * Reset step = origin - last = (0,0) - (3,0) = (-3, 0)
+     */
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 4>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
+                                                                    Sequence<0, 1>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], -3); // return 3 rows up
+    EXPECT_EQ(reset[Number<1>{}], 0);  // column already at origin
+}
+
+TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_ColMajor)
+{
+    /*
+     * Same 2D slice but column-major traversal order.
+     *
+     * SliceLengths    = (4, 8)  -- 4 rows, 8 columns
+     * DimAccessOrder  = (1, 0)  -- column-major (columns change slowest)
+     * ScalarPerAccess = (1, 4)  -- 4-wide vector loads along columns
+     *
+     * access_lengths = (4, 2)
+     * ordered_access_lengths = reorder_new2old((4,2), (1,0)) = (2, 4)
+     * (dim 1 is the "slow" outer dimension, dim 0 is the "fast" inner)
+     *
+     * Traversal (ordered dims are [col_block, row]):
+     *
+     *   col_block:  0    1
+     *   row 0:     [0]  [7]
+     *   row 1:     [1]  [6]
+     *   row 2:     [2]  [5]
+     *   row 3:     [3]  [4]
+     *
+     * Unordered indices (natural dim order):
+     *   access 0 -> (row=0, col=0*4=0)
+     *   access 3 -> (row=3, col=0)
+     *   access 4 -> (row=3, col=1*4=4)   (serpentine reversal in row)
+     *   access 7 -> (row=0, col=4)
+     *
+     * Last access (#7) lands at index (0, 4).
+     * Reset step = (0,0) - (0,4) = (0, -4)
+     */
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 4>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
+                                                                    Sequence<1, 0>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], 0);  // row already at origin
+    EXPECT_EQ(reset[Number<1>{}], -4); // return 4 columns left
+}
+
+TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_3D)
+{
+    /*
+     * 3D slice transfer, modelling a batch x row x col tile as used in
+     * batched GEMM or attention kernels (v7r2/v7r3).
+     *
+     * SliceLengths    = (2, 4, 8)  -- 2 batches, 4 rows, 8 columns
+     * DimAccessOrder  = (0, 1, 2)  -- batch outermost, column innermost
+     * ScalarPerAccess = (1, 1, 8)  -- 8-wide vector loads on columns
+     *
+     * access_lengths = (2, 4, 1)
+     * Total accesses = 2 * 4 * 1 = 8
+     *
+     * Traversal within each batch is serpentine on rows, columns scalar:
+     *
+     *   batch 0:
+     *     row 0: [0]   -- (0, 0, 0)
+     *     row 1: [1]   -- (0, 1, 0)
+     *     row 2: [2]   -- (0, 2, 0)
+     *     row 3: [3]   -- (0, 3, 0)
+     *
+     *   batch 1: (serpentine reversal on rows)
+     *     row 3: [4]   -- (1, 3, 0)
+     *     row 2: [5]   -- (1, 2, 0)
+     *     row 1: [6]   -- (1, 1, 0)
+     *     row 0: [7]   -- (1, 0, 0)
+     *
+     * Last access (#7) lands at index (1, 0, 0).
+     * Reset step = (0,0,0) - (1,0,0) = (-1, 0, 0)
+     */
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 1, 8>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<2, 4, 8>,
+                                                                    Sequence<0, 1, 2>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], -1); // return 1 batch
+    EXPECT_EQ(reset[Number<1>{}], 0);  // row already at origin (serpentine came back)
+    EXPECT_EQ(reset[Number<2>{}], 0);  // column at origin (single access per row)
+}
+
+TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_EvenInnerAccesses)
+{
+    /*
+     * When the number of accesses along the inner dimension is even, the
+     * serpentine traversal returns to the starting side on that dimension.
+     *
+     * SliceLengths    = (4, 4)
+     * DimAccessOrder  = (0, 1)
+     * ScalarPerAccess = (1, 2)  -- 2-wide vector loads
+     *
+     * access_lengths = (4, 2)  -- 2 accesses along cols (even)
+     *
+     *        col: 0..1   2..3
+     *   row 0:    [0]--> [1]    access 0 -> (0,0), access 1 -> (0,2)
+     *   row 1:    [3]<-- [2]    access 2 -> (1,2), access 3 -> (1,0)
+     *   row 2:    [4]--> [5]    access 4 -> (2,0), access 5 -> (2,2)
+     *   row 3:    [7]<-- [6]    access 6 -> (3,2), access 7 -> (3,0)
+     *
+     * Last access (#7) at (3, 0).
Even number of column accesses (2)
+     * means the serpentine always returns to col=0 at the end of each row.
+     * Reset step = (0,0) - (3,0) = (-3, 0)
+     */
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 2>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 4>,
+                                                                    Sequence<0, 1>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], -3);
+    EXPECT_EQ(reset[Number<1>{}], 0); // even inner accesses -> back at start column
+}
+
+TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_OddInnerAccesses)
+{
+    /*
+     * When the number of accesses along the inner dimension is odd and the
+     * outer dimension is even, the serpentine returns to col=0.
+     *
+     * SliceLengths    = (2, 6)
+     * DimAccessOrder  = (0, 1)
+     * ScalarPerAccess = (1, 2)  -- 2-wide vector loads
+     *
+     * access_lengths = (2, 3)  -- 3 accesses along cols (odd!)
+     *
+     *        col: 0..1   2..3   4..5
+     *   row 0:    [0]--> [1]--> [2]    access 0 -> (0,0), 1 -> (0,2), 2 -> (0,4)
+     *   row 1:    [5]<-- [4]<-- [3]    access 3 -> (1,4), 4 -> (1,2), 5 -> (1,0)
+     *
+     * Last access (#5) at (1, 0). Even row count means serpentine reversal
+     * on the inner dim brings us back to col=0.
+     * Reset step = (0,0) - (1,0) = (-1, 0)
+     */
+    using SFCHelper = ThreadwiseTransferHelper_SFC;
+
+    constexpr auto scalar_per_access = Sequence<1, 2>{};
+    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<2, 6>,
+                                                                    Sequence<0, 1>,
+                                                                    decltype(scalar_per_access)>();
+
+    EXPECT_EQ(reset[Number<0>{}], -1); // return 1 row
+    EXPECT_EQ(reset[Number<1>{}], 0);  // even outer accesses -> serpentine came back to col=0
+}
+
+// =============================================================================
+// Inheritance structure tests
+// =============================================================================
+
+TEST(ThreadwiseTransferHelperInheritance, SerpentineIsDerivedFromBase)
+{
+    /*
+     * ThreadwiseTransferHelper_Base
+     *   |
+     *   +-- ThreadwiseTransferHelper_Serpentine   <-- this relationship
+     *   |
+     *   +-- ThreadwiseTransferHelper_SFC
+     */
+    static_assert(
+        std::is_base_of_v<ThreadwiseTransferHelper_Base, ThreadwiseTransferHelper_Serpentine>);
+}
+
+TEST(ThreadwiseTransferHelperInheritance, SFCIsDerivedFromBase)
+{
+    /*
+     * ThreadwiseTransferHelper_Base
+     *   |
+     *   +-- ThreadwiseTransferHelper_Serpentine
+     *   |
+     *   +-- ThreadwiseTransferHelper_SFC   <-- this relationship
+     */
+    static_assert(std::is_base_of_v<ThreadwiseTransferHelper_Base, ThreadwiseTransferHelper_SFC>);
+}
+
+TEST(ThreadwiseTransferHelperInheritance, SerpentineAndSFCAreNotRelated)
+{
+    /*
+     * Serpentine and SFC are siblings -- neither inherits from the other.
+     *
+     * ThreadwiseTransferHelper_Base
+     *   |
+     *   +-- Serpentine   (NOT parent of SFC)
+     *   |
+     *   +-- SFC          (NOT parent of Serpentine)
+     */
+    static_assert(
+        !std::is_base_of_v<ThreadwiseTransferHelper_Serpentine, ThreadwiseTransferHelper_SFC>);
+    static_assert(
+        !std::is_base_of_v<ThreadwiseTransferHelper_SFC, ThreadwiseTransferHelper_Serpentine>);
+}
+
+// =============================================================================
+// detail:: functor tests
+// =============================================================================
+
+TEST(DetailFunctors, LambdaScalarPerAccess)
+{
+    /*
+     * For VectorDim=1 and ScalarPerVector=8:
+     *
+     *   dim:    0  1  2
+     *   result: 1  8  1
+     *           ^  ^  ^
+     *           |  |  +-- not the vector dim
+     *           |  +----- THE vector dim (returns ScalarPerVector)
+     *           +-------- not the vector dim
+     */
+    constexpr auto f = detail::lambda_scalar_per_access<1, 8>{};
+
+    EXPECT_EQ(f(0), 1);
+    EXPECT_EQ(f(1), 8);
+    EXPECT_EQ(f(2), 1);
+}
+
+TEST(DetailFunctors, LambdaScalarStepInVector)
+{
+    /*
+     * For VectorDim=2:
+     *
+     *   dim:    0  1  2  3
+     *   result: 0  0  1  0
+     *                 ^
+     *                 +-- THE vector dim (step = 1)
+     */
+    constexpr auto f = detail::lambda_scalar_step_in_vector<2>{};
+
+    EXPECT_EQ(f(0), 0);
+    EXPECT_EQ(f(1), 0);
+    EXPECT_EQ(f(2), 1);
+    EXPECT_EQ(f(3), 0);
+}
+
+TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_SameDim)
+{
+    /*
+     * Src and Dst both vectorize dim 1:
+     *   SrcVectorDim=1, SrcScalarPerVector=4
+     *   DstVectorDim=1, DstScalarPerVector=8
+     *
+     *   dim:    0     1      2
+     *   result: 1  lcm(4,8)  1
+     *              = 8
+     */
+    constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<1, 4, 1, 8>{};
+
+    EXPECT_EQ(f(0), 1);
+    EXPECT_EQ(f(1), 8); // lcm(4, 8) = 8
+    EXPECT_EQ(f(2), 1);
+}
+
+TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_DifferentDims)
+{
+    /*
+     * Src vectorizes dim 0 (width 4), Dst vectorizes dim 2 (width 8):
+     *
+     *   dim:    0       1    2
+     *   result: 4(src)  1    8(dst)
+     */
+    constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<0, 4, 2, 8>{};
+
+    EXPECT_EQ(f(0), 4); // src vector dim
+    EXPECT_EQ(f(1), 1); // neither
+    EXPECT_EQ(f(2), 8); // dst vector dim
+}