diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
index 5035fe23d0..c2b54e2ba3 100644
--- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
+++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp
@@ -1,19 +1,28 @@
 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
 // SPDX-License-Identifier: MIT
 #pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_description/tensor_space_filling_curve.hpp"
+
 namespace ck {
-// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
-// and sometimes useless instructions:
-// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
-// instead
-// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
-// tensor coordinate instead
-// 3. Don't use a pointer to VGPR buffer, use vector instead
+/**
+ * @file threadwise_tensor_slice_transfer_util.hpp
+ * @brief Shared helper class hierarchy for threadwise tensor slice transfer variants.
+ *
+ * Provides a base class and two traversal-specific helpers derived from it:
+ *
+ * - @ref ThreadwiseTransferHelper_Base -- generic coordinate/descriptor utilities
+ * - @ref ThreadwiseTransferHelper_Serpentine -- serpentine (snake/zigzag) traversal
+ * - @ref ThreadwiseTransferHelper_SFC -- SpaceFillingCurve traversal
+ */
 namespace detail {
-// TODO: How to fix this? It uses an struct instead of lambda because lambda
-// doesn't have constructor
+
+/** @brief Functor returning ScalarPerVector for dimension VectorDim, 1 otherwise. */
 template
 struct lambda_scalar_per_access
 {
@@ -23,6 +32,7 @@ struct lambda_scalar_per_access
     }
 };
+/** @brief Functor returning 1 for dimension VectorDim, 0 otherwise. */
 template
 struct lambda_scalar_step_in_vector
 {
@@ -32,8 +42,10 @@ struct lambda_scalar_step_in_vector
     }
 };
-// TODO: How to fix this? It uses an struct instead of lambda because lambda
-// doesn't have constructor
+/**
+ * @brief Functor computing the combined src/dst scalar-per-access for each dimension:
+ * the lcm of the two vector widths on a shared vector dimension, the respective vector
+ * width on a src-only or dst-only vector dimension, and 1 elsewhere.
+ */
 template
-struct lambda_wave_cluster_dimension
+} // namespace detail
+
+/**
+ * @brief Base helper with methods shared by all threadwise transfer variants.
+ *
+ * Both ThreadwiseTransferHelper_Serpentine and ThreadwiseTransferHelper_SFC
+ * inherit from this class. Contains generic coordinate stepping, thread scratch
+ * descriptor construction, and compile-time index constants.
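+ *
+ * @note Illustrative usage, based on the call sites in this patch: a transfer class
+ * aliases one of the derived helpers and forwards to its static methods, e.g.
+ * @code
+ *   using Helper = ThreadwiseTransferHelper_Serpentine;
+ *   const auto forward_steps  = Helper::ComputeForwardSteps(src_desc, scalar_per_access);
+ *   const auto backward_steps = Helper::ComputeBackwardSteps(src_desc, scalar_per_access);
+ * @endcode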
+ */ +struct ThreadwiseTransferHelper_Base { - __host__ __device__ constexpr auto operator()(index_t i) const + /** + * @name Compile-time index constants + * @{ + */ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + /** @} */ + + /** + * @brief Move the slice window by a step, optionally fusing coordinate reset. + * + * If the coordinate was not reset after RunRead/RunWrite, the reset step is + * added to the movement step to avoid a separate coordinate adjustment. + * + * @tparam ResetCoordinateAfterRun Whether the coordinate was already reset. + * @param desc Tensor descriptor. + * @param coord Tensor coordinate to move (modified in place). + * @param slice_origin_step_idx Step index for the slice window movement. + * @param get_reset_step Callable returning the coordinate reset step. + */ + template + __host__ __device__ static void MoveSliceWindow(const Desc& desc, + Coord& coord, + const StepIdx& slice_origin_step_idx, + GetCoordinateResetStepFunc get_reset_step) { - if((nDim - i) == 3) - return WaveNum; - else - return 1; + const auto adjusted_step_idx = ResetCoordinateAfterRun + ? slice_origin_step_idx + : slice_origin_step_idx + get_reset_step(); + + const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx); + + move_tensor_coordinate(desc, coord, adjusted_step); + } + + /** + * @brief Build the thread-local scratch tensor descriptor. + * + * Creates a transformed tensor descriptor where the vector dimension is merged + * with an additional dimension of size ScalarPerVector, enabling vector-typed + * access to the scratch buffer. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam VectorDim Which dimension is vectorized. + * @tparam ScalarPerVector_ Number of scalars per vector load/store. + * @return Transformed tensor descriptor for the thread scratch buffer. 
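+ *
+ * For example (illustrative values, not from a specific call site): with SliceLengths = {4, 8},
+ * VectorDim = 1 and ScalarPerVector_ = 4, the access lengths become {4, 2}, the packed scratch
+ * holds 4 x 2 vectors of 4 scalars, and the merge transform re-exposes it as a 4 x 8 element view.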
+ */ + template + __host__ __device__ static constexpr auto ComputeThreadScratchDescriptor() + { + constexpr index_t nDim = SliceLengths::Size(); + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(access_lengths_and_vector_length); + + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == VectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(access_lengths_and_vector_length[i], + access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == VectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = generate_identity_sequences(); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + /** + * @brief Compute forward (+1) coordinate steps for each dimension. + * + * Returns a tuple of nDim coordinate steps, where step[i] moves by + * +scalar_per_access[i] in dimension i and 0 in all other dimensions. + * + * @param desc Tensor descriptor. + * @param scalar_per_access Per-dimension access widths (Sequence type). + */ + template + __host__ __device__ static constexpr auto + ComputeForwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + return generate_tuple( + [&](auto i) { + MultiIndex step_idx; + + static_for<0, nDim, 1>{}( + [&](auto j) { step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + /** + * @brief Compute backward (-1) coordinate steps for each dimension. + * + * Same as ComputeForwardSteps but with negated step values. + * + * @param desc Tensor descriptor. + * @param scalar_per_access Per-dimension access widths (Sequence type). + */ + template + __host__ __device__ static constexpr auto + ComputeBackwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + return generate_tuple( + [&](auto i) { + MultiIndex step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + /** + * @brief Create a tuple of default-constructed vector containers, one per data type. + * + * @tparam DataTypes Tuple of data types (e.g., Tuple). + * @tparam ScalarPerVector Number of scalars per vector. + * @return Tuple of vector_type_maker_t instances. + */ + template + __host__ __device__ static auto MakeVectorContainerTuple() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); } }; -} // namespace detail +/** + * @brief Serpentine (snake/zigzag) traversal helper. 
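+ * Each access advances the coordinate by a single forward or backward step in one
+ * dimension, which keeps per-access coordinate updates cheap.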
+ * + * Provides methods for computing serpentine sweep directions, dimension movement + * decisions, and coordinate reset steps used by the v3r1 family of transfer classes. + * + * Used by: ThreadwiseTensorSliceTransfer_v3r1, v3r2, v3r1_gather, v3r1_dequant. + */ +struct ThreadwiseTransferHelper_Serpentine : ThreadwiseTransferHelper_Base +{ + /** + * @brief Binary decomposition of vector widths 0-16 into power-of-2 sub-load sizes. + * Index N gives the sequence of sub-load widths whose sum equals N. + * E.g. index 7 -> Sequence means loads of width 4, 2, 1. + */ + using VectorSizeLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + /** + * @brief Starting offsets for each sub-load in VectorSizeLookupTable. + * E.g. index 7 -> Sequence means offsets 0, 4, 6. + */ + using VectorOffsetsLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + /** + * @brief Compute serpentine sweep direction for each dimension. + * + * Determines whether each dimension should be traversed forward or backward + * based on the current position in the ordered access grid, implementing + * a zigzag (serpentine) traversal pattern. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @return Array of booleans: true = forward, false = backward. + */ + template + __host__ __device__ static constexpr auto + ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + constexpr index_t nDim = OrderedAccessLengths::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "ordered_access_idx and ordered_access_lengths must have same nDim"); + StaticallyIndexedArray_v2 forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, i, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + } + + /** + * @brief Determine which dimensions need coordinate movement at a given iteration. + * + * A dimension moves when it hasn't reached its end and all higher-priority + * (faster-varying) dimensions have completed their ranges. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @return Array of booleans: true = move coordinate on this dimension. 
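+ *
+ * Illustrative example: with ordered lengths {2, 3}, at index {0, 2} dimension 1 is at its
+ * last position, so only dimension 0 reports a move; at index {0, 1} only dimension 1 moves.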
+ */ + template + __host__ __device__ static constexpr auto + ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths) + { + constexpr index_t nDim = OrderedAccessLengths::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "ordered_access_idx and ordered_access_lengths must have same nDim"); + StaticallyIndexedArray_v2 move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + + /** + * @brief Convert ordered access index to natural dimension order and apply scaling. + * + * @param ordered_access_idx Current position in the ordered access grid. + * @param ordered_access_lengths Size of the ordered access grid per dimension. + * @param forward_sweep Per-dimension sweep direction. + * @param dim_access_order Mapping from ordered to natural dimension indices. + * @param scalar_per_access Per-dimension access widths. + * @return MultiIndex in natural dimension order, scaled by scalar_per_access. + */ + template + __host__ __device__ static constexpr auto + ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx, + const OrderedAccessLengths& ordered_access_lengths, + const ForwardSweep& forward_sweep, + const DimAccessOrder& dim_access_order, + const ScalarPerAccess& scalar_per_access) + { + constexpr index_t nDim = ScalarPerAccess::Size(); + static_assert(OrderedAccessIdx::Size() == nDim, + "all arguments to ComputeDataIndex must have same nDim"); + MultiIndex ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; + } + + /** + * @brief Compute the coordinate step needed to return to the origin after traversal. + * + * Determines where the coordinate ends up after a full serpentine traversal, + * then returns the negated position as the reset step. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam VectorDim Which dimension is vectorized. + * @tparam ScalarPerVector_ Number of scalars per vector load/store. + * @tparam DimAccessOrder Compile-time sequence mapping ordered to natural dims. + * @return MultiIndex representing the step to reset the coordinate to the origin. + */ + template + __host__ __device__ static constexpr auto ComputeCoordinateResetStep() + { + constexpr index_t nDim = SliceLengths::Size(); + static_assert(DimAccessOrder::Size() == nDim, + "SliceLengths and DimAccessOrder must have same nDim"); + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + constexpr auto ordered_access_lengths_minus_1 = generate_tuple( + [&](auto i) { return Number{}; }, Number{}); + constexpr auto forward_sweep = + ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths); + + constexpr auto reset_step = [&]() { + MultiIndex ordered_idx; + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; + }); + + auto data_idx = + container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; + + MultiIndex step; + static_for<0, nDim, 1>{}([&](auto i) { step(i) = -data_idx[i]; }); + return step; + }(); + + return reset_step; + } +}; + +/** + * @brief SpaceFillingCurve traversal helper. + * + * Provides coordinate reset computation using SpaceFillingCurve's GetStepBetween + * method, which computes the step from the last access position back to the origin. + * + * Used by: ThreadwiseTensorSliceTransfer v6r1, v6r1r2, v6r2, v6r3, v7r2, v7r3, + * v7r3_scatter. + */ +struct ThreadwiseTransferHelper_SFC : ThreadwiseTransferHelper_Base +{ + /** + * @brief Compute the coordinate reset step using SpaceFillingCurve traversal. + * + * @tparam SliceLengths Compile-time sequence of per-dimension slice lengths. + * @tparam DimAccessOrder Compile-time sequence defining dimension access order. + * @tparam ScalarPerAccess Compile-time sequence of per-dimension access widths. + * @return MultiIndex representing the step from last access position to origin. + */ + template + __host__ __device__ static constexpr auto ComputeSFCCoordinateResetStep() + { + using SFC = SpaceFillingCurve>; + + constexpr auto num_access = SFC::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SFC::Index{}; + } + else + { + return SFC::GetStepBetween(Number{}, Number<0>{}); + } + } +}; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 7b9d136068..8b0b35935f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -44,263 +44,63 @@ template struct ThreadwiseTensorSliceTransfer_v3r1 { + // ===================================================================== + // Private type aliases and constants + // ===================================================================== + private: + using Helper = ThreadwiseTransferHelper_Serpentine; + static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I10 = Number<10>{}; - static constexpr auto I12 = Number<12>{}; - static constexpr auto I13 = Number<13>{}; - static constexpr auto I14 = Number<14>{}; - static constexpr auto I16 = Number<16>{}; - - static constexpr index_t PackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); + static constexpr index_t PackedSize = is_same_v, pk_i4_t> ? 
2 : 1; static constexpr auto SrcScalarPerVector = Number{}; static constexpr auto DstScalarPerVector = Number{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( - const SrcDesc& src_desc, - const Index& src_slice_origin, - const SrcElementwiseOperation& src_element_op, - const DstDesc& dst_desc, - const Index& dst_slice_origin, - const DstElementwiseOperation& dst_element_op) - : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), - dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), - src_element_op_(src_element_op), - dst_element_op_(dst_element_op) + // ===================================================================== + // Private implementation methods (must be declared before public methods + // that call them) + // ===================================================================== + __device__ static constexpr auto GetSrcCoordinateResetStep() { - if constexpr((packed_size_v) > 1) - { - static_assert(is_same_v, remove_cvref_t>, - "SrcData != DstData"); - - static_assert( - SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, - "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); - - static_assert(SrcVectorDim == DstVectorDim, - "Packed data type does not support transpose"); - } + return Helper::ComputeCoordinateResetStep(); } - __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + __device__ static constexpr auto GetDstCoordinateResetStep() { - src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + return Helper::ComputeCoordinateResetStep(); } - __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + return Helper:: + ComputeThreadScratchDescriptor(); } - template - __device__ void RunRead(const SrcDesc& src_desc, - const SrcBuffer& src_buf, - Number thread_scratch_id = Number{}) + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() { - static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or - SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, - "wrong!"); - - static_assert( - is_same, remove_cvref_t>::value, - "wrong! 
SrcBuffer and SrcData data type are inconsistent"); - - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, - "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // make forward and backward steps - const auto src_forward_steps = ComputeForwardSteps(src_desc, src_scalar_per_access); - const auto src_backward_steps = ComputeBackwardSteps(src_desc, src_scalar_per_access); - - // loop over tensor and copy - static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = - ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - - // calculate src data index - constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); - - constexpr auto src_data_idx_seq = generate_sequence_v2( - [&](auto i) { return Number{}; }, Number{}); - - // maintain a container record is_src_valid, waiting for RunWrite use. - const bool is_src_valid = - coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); - src_oob_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, is_src_valid); - - using dst_vector_type = vector_type_maker_t; - using dst_vector_t = typename dst_vector_type::type; - dst_vector_type op_r_v; - - constexpr auto get_elem_op_vec_len = []() { - if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack8_invocable) - return math::min(8, SrcScalarPerVector); - } - else if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack4_invocable) - return math::min(4, SrcScalarPerVector); - } - else if constexpr(is_detected::value) - { - if constexpr(decltype(src_element_op_)::is_pack2_invocable) - return math::min(2, SrcScalarPerVector); - } - else - { - return 1; - } - }; - - constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); - - using src_elem_op_vec_t = typename vector_type::type; - using dst_elem_op_vec_t = typename vector_type::type; - - using VectorSizeLookupTable = Tuple, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence>; - using VectorOffsetsLookupTable = Tuple, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence>; - - static_for<0, tuple_element_t::Size(), 1>{}( - [&](auto v_idx) { - constexpr auto VectorLoadSize = - tuple_element_t::At(v_idx); - constexpr auto LoadOffset = - tuple_element_t::At(v_idx); - - using src_vector_container = vector_type_maker_t; - using src_vector_container_t = typename src_vector_container::type; - - src_vector_container src_vector = - src_vector_container{src_buf.template Get( - src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; - - static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { - // 
apply the src elementwise op and convert to DstData under the hood if - // needed - src_element_op_( - op_r_v.template AsType()(idx + LoadOffset), - src_vector.template AsType()[idx]); - }); - }); - - // copy data from src_vector_container into src_thread_scratch_ - src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, - op_r_v.template AsType()[I0]); - - constexpr auto move_on_dim = - ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - - // move src coord - static_for<0, nDim, 1>{}([&](auto i) { - if constexpr(move_on_dim[i]) - { - if constexpr(forward_sweep[i]) - { - move_tensor_coordinate( - src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); - } - else - { - move_tensor_coordinate( - src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); - } - } - }); - }); - - // move src coordinate back to slice origin (or not) - if constexpr(SrcResetCoordinateAfterRun) - { - const auto src_reset_step = - make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); - - move_tensor_coordinate(src_desc, src_coord_, src_reset_step); - } + return make_naive_tensor_descriptor_packed(src_access_lengths); } - template - __device__ constexpr auto - GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) + __device__ static constexpr auto GetDstThreadScratchDescriptor() { - using vector_t = typename vector_type_maker::type::type; - return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); + return Helper:: + ComputeThreadScratchDescriptor(); } template @@ -327,14 +127,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1 static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = - ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx, - ordered_src_access_lengths, - forward_sweep, - src_dim_access_order, - src_scalar_per_access); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -439,6 +239,194 @@ struct ThreadwiseTensorSliceTransfer_v3r1 #endif } + // ===================================================================== + // Public interface + // ===================================================================== + public: + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + if constexpr((packed_size_v) > 1) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); + } + } + + __device__ void 
SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); + + // calculate src data index + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // maintain a container record is_src_valid, waiting for RunWrite use. 
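+            // Per access, the code below (1) records whether the current src offset is in
+            // bounds, (2) picks the widest vector width src_element_op_ can consume (pack8,
+            // pack4 or pack2 when available, capped by SrcScalarPerVector), (3) splits the
+            // SrcScalarPerVector-wide load into power-of-two sub-loads via the lookup tables,
+            // applying the elementwise op chunk by chunk, and (4) stores the result into the
+            // thread scratch at src_data_idx.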
+ const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, is_src_valid); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + else + { + return 1; + } + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, + tuple_element_t::Size(), + 1>{}([&](auto v_idx) { + constexpr auto VectorLoadSize = + tuple_element_t::At(v_idx); + constexpr auto LoadOffset = + tuple_element_t::At( + v_idx); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + src_vector_container src_vector = + src_vector_container{src_buf.template Get( + src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; + + static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if + // needed + src_element_op_(op_r_v.template AsType()(idx + LoadOffset), + src_vector.template AsType()[idx]); + }); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, op_r_v.template AsType()[Helper::I0]); + + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + template __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, @@ -470,21 +458,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); // make forward and backward steps - const auto dst_forward_steps = ComputeForwardSteps(dst_desc, dst_scalar_per_access); - const auto dst_backward_steps = ComputeBackwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { // judge move forward or move backward constexpr auto forward_sweep = - 
ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); // calculate dst data index - constexpr auto dst_data_idx = ComputeDataIndex(ordered_dst_access_idx, - ordered_dst_access_lengths, - forward_sweep, - dst_dim_access_order, - dst_scalar_per_access); + constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -510,10 +499,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1 dst_buf.template Set( dst_coord_.GetOffset() / PackedSize, is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); constexpr auto move_on_dim = - ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -543,21 +532,19 @@ struct ThreadwiseTensorSliceTransfer_v3r1 } } - __device__ static constexpr auto GetSrcCoordinateResetStep() + template + __device__ constexpr auto + GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) { - return ComputeCoordinateResetStep(); - } - - __device__ static constexpr auto GetDstCoordinateResetStep() - { - return ComputeCoordinateResetStep(); + using vector_t = typename vector_type_maker::type::type; + return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - MoveSliceWindow( + Helper::MoveSliceWindow( src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } @@ -565,252 +552,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - MoveSliceWindow( + Helper::MoveSliceWindow( dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } - __device__ static constexpr auto GetSrcThreadScratchDescriptor() - { - return ComputeThreadScratchDescriptor(); - } - - __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() - { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - return make_naive_tensor_descriptor_packed(src_access_lengths); - } - - __device__ static constexpr auto GetDstThreadScratchDescriptor() - { - return ComputeThreadScratchDescriptor(); - } - - protected: - // Helper function to compute forward sweep pattern - // I.e. 
if we should move forward or backward in each of tensor's dimensions - template - __device__ static constexpr auto - ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& ordered_access_lengths) - { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_access_idx[I0]; - - static_for<1, i, 1>{}( - [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - } - - // Compute which dimensions should have their coordinates updated during iteration - // A dimension moves when it hasn't reached its end and all higher priority dimensions - // have completed their ranges - template - __device__ static constexpr auto - ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& ordered_access_lengths) - { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - } - - // Compute data index from ordered access index, converting back to natural order - template - __device__ static constexpr auto - ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx, - const OrderedAccessLengths& ordered_access_lengths, - const ForwardSweep& forward_sweep, - const DimAccessOrder& dim_access_order, - const ScalarPerAccess& scalar_per_access) - { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_access_idx[i] - : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access; - } - - // Compute forward coordinate steps for each dimension - template - __device__ static constexpr auto ComputeForwardSteps(const Desc& desc, - const ScalarPerAccess& scalar_per_access) - { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(desc, forward_step_idx); - }, - Number{}); - } - - // Compute backward coordinate steps for each dimension - template - __device__ static constexpr auto ComputeBackwardSteps(const Desc& desc, - const ScalarPerAccess& scalar_per_access) - { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? 
-scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(desc, backward_step_idx); - }, - Number{}); - } - - // Generic helper to compute thread scratch descriptor - template - __device__ static constexpr auto ComputeThreadScratchDescriptor() - { - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - constexpr auto access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == VectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(access_lengths_and_vector_length[i], - access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == VectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); - } - - // Generic helper to move slice window - template - __device__ static void MoveSliceWindow(const Desc& desc, - Coord& coord, - const Index& slice_origin_step_idx, - GetCoordinateResetStepFunc get_reset_step) - { - // if coord was not reset by RunRead/RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = ResetCoordinateAfterRun - ? slice_origin_step_idx - : slice_origin_step_idx + get_reset_step(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx); - - move_tensor_coordinate(desc, coord, adjusted_step); - } - - // Generic helper to compute coordinate reset step - template - __device__ static constexpr auto ComputeCoordinateResetStep() - { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto access_lengths = SliceLengths{} / scalar_per_access; - - constexpr auto dim_access_order = DimAccessOrder{}; - - constexpr auto ordered_access_lengths = - container_reorder_given_new2old(access_lengths, dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto ordered_access_lengths_minus_1 = generate_tuple( - [&](auto i) { return Number{}; }, Number{}); - constexpr auto forward_sweep = - ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths); - - // calculate data index after last iteration, if it has not being reset - constexpr auto data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - scalar_per_access; - }(); - - // - constexpr auto reset_data_step = [&]() { - Index reset_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; }); - - return reset_data_step_; - }(); - - return reset_data_step; - } - + // ===================================================================== + // Private data members + // ===================================================================== private: static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; static constexpr auto src_oob_thread_scratch_desc_ = diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp index 2ddb34671a..7545c8c416 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -7,43 +7,12 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { -namespace detail { -// TODO: How to fix this? It uses an struct instead of lambda because lambda -// doesn't have constructor -template -struct lambda_scalar_per_access_for_src_and_dst_idle -{ - __host__ __device__ constexpr auto operator()(index_t i) const - { - if(i == SrcVectorDim && i == DstVectorDim) - { - return math::lcm(SrcScalarPerVector, DstScalarPerVector); - } - else if(i == SrcVectorDim) - { - return SrcScalarPerVector; - } - else if(i == DstVectorDim) - { - return DstScalarPerVector; - } - else - { - return 1; - } - } -}; - -} // namespace detail - // Assume: // 1. src_desc and dst_desc are not known at compile-time // 2. SrcBuffer and DstBuffer are DynamicBuffer @@ -84,12 +53,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant static constexpr index_t nDim = SliceLengths::Size(); using Index = MultiIndex; + using Helper = ThreadwiseTransferHelper_Serpentine; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -139,7 +108,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! SrcBuffer and SrcData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -150,66 +118,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? 
src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -227,22 +150,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType( - src_data_idx_seq, src_vector_container.template AsType()[I0]); + src_data_idx_seq, + src_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -284,7 +196,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! 
ScaleBuffer and ScaleData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto scale_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -295,66 +206,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_scale_access_lengths = container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); - // make forward steps - const auto scale_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(scale_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto scale_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(scale_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto scale_forward_steps = + Helper::ComputeForwardSteps(scale_desc, scale_scalar_per_access); + const auto scale_backward_steps = + Helper::ComputeBackwardSteps(scale_desc, scale_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_scale_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_scale_access_idx, ordered_scale_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_scale_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate scale data index - constexpr auto scale_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_scale_access_idx[i] - : ordered_scale_access_lengths[i] - 1 - - ordered_scale_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * - scale_scalar_per_access; - }(); + constexpr auto scale_data_idx = Helper::ComputeDataIndex(ordered_scale_access_idx, + ordered_scale_access_lengths, + forward_sweep, + scale_dim_access_order, + scale_scalar_per_access); constexpr auto scale_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -372,23 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // copy data from scale_vector_container into scale_thread_scratch_ scale_thread_scratch_.template SetAsType( - scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + scale_data_idx_seq, + scale_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = - ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_scale_access_idx, ordered_scale_access_lengths); // move scale coord static_for<0, nDim, 1>{}([&](auto i) { @@ -409,17 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant } }); }); - - // don't need to move scale coordinate back to slice origin - /* - if constexpr(SrcResetCoordinateAfterRun) - { - const auto scale_reset_step = - make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); - - move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); - } - */ } template @@ -460,10 +304,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant detail::lambda_scalar_step_in_vector{}, Number{}); constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + detail::lambda_scalar_per_access_for_src_and_dst{}, Number{}); constexpr auto access_lengths = SliceLengths{} / scalar_per_access; @@ -504,10 +348,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant // Do fast numeric convert constexpr auto scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + detail::lambda_scalar_per_access_for_src_and_dst{}, Number{}); constexpr auto access_lengths = SliceLengths{} / scalar_per_access; @@ -528,15 +372,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant src_converted_thread_scratch_.template SetAsType( access_idx, - src_converted_vector_container.template AsType()[I0]); + src_converted_vector_container + .template AsType()[Helper::I0]); }); // Element-scale operation, expect packed multiplication static_ford{}([&](auto idx) { DstData dst_v; - constexpr auto scale_idx = Sequence{}; - // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), - // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + constexpr auto scale_idx = Sequence{}; src_element_op_(dst_v, src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); dst_thread_scratch_(idx) = dst_v; @@ -562,7 +405,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant "wrong! 
SrcBuffer or DstBuffer data type is wrong"); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -573,66 +415,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); + constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -660,22 +457,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant dst_buf.template Set( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -707,293 +492,52 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; + return Helper::ComputeCoordinateResetStep(); } __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + Helper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? 
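The Helper::ComputeForwardSweep / ComputeDataIndex / ComputeMoveOnDim calls in these hunks replace the snake-order bookkeeping that the removed lambdas spelled out inline. Below is a minimal host-side sketch of that traversal rule, assuming plain ints and std::array in place of ck::Number, Sequence and static_for, and omitting the DimAccessOrder reordering and scalar_per_access scaling that the real helpers also apply.

#include <array>
#include <cstdio>

constexpr int nDim = 2;
using Idx = std::array<int, nDim>;

// forward_sweep[i]: walk dimension i forwards iff the flattened counter of the
// dimensions ordered before it is even (serpentine / zigzag order).
Idx forward_sweep(const Idx& idx, const Idx& lengths)
{
    Idx sweep{};
    sweep[0] = 1;
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = idx[0];
        for(int j = 1; j < i; ++j)
            tmp = tmp * lengths[j] + idx[j];
        sweep[i] = (tmp % 2 == 0) ? 1 : 0;
    }
    return sweep;
}

// Data index actually visited: mirrored whenever the sweep runs backwards.
Idx data_index(const Idx& idx, const Idx& lengths, const Idx& sweep)
{
    Idx out{};
    for(int i = 0; i < nDim; ++i)
        out[i] = sweep[i] ? idx[i] : lengths[i] - 1 - idx[i];
    return out;
}

int main()
{
    const Idx lengths{2, 3}; // 2 x 3 access grid
    for(int a = 0; a < lengths[0]; ++a)
        for(int b = 0; b < lengths[1]; ++b)
        {
            const Idx idx{a, b};
            const Idx d = data_index(idx, lengths, forward_sweep(idx, lengths));
            std::printf("(%d,%d) -> (%d,%d)\n", a, b, d[0], d[1]);
        }
    // Visits (0,0)(0,1)(0,2)(1,2)(1,1)(1,0): consecutive accesses differ by one
    // step in exactly one dimension, so each iteration needs only a single
    // forward or backward coordinate step.
}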
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + Helper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetScaleThreadScratchDescriptor() { - - constexpr auto scale_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; - - constexpr auto scale_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(scale_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(scale_access_lengths_and_vector_length[i], - scale_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if 
constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } private: @@ -1002,11 +546,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant decltype(GetScaleThreadScratchDescriptor()){}; static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; - /* - template - struct ScaleThreadScratchDesc{}; - */ - // Registers, contain raw data loaded from global buffer using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + using Helper = ThreadwiseTransferHelper_Serpentine; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; - static constexpr auto I8 = Number<8>{}; - static constexpr auto I10 = Number<10>{}; - static constexpr auto I12 = Number<12>{}; - static constexpr auto I13 = Number<13>{}; - static constexpr auto I14 = Number<14>{}; - static constexpr auto I16 = Number<16>{}; - static constexpr index_t PackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -142,7 +126,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather "wrong! SrcBuffer and SrcData data type are inconsistent"); // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -156,66 +139,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // make forward steps - const auto src_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto src_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(src_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access); + const auto src_backward_steps = + Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -274,24 +214,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // copy data from src_vector_container into src_thread_scratch_ src_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, - op_r_v.template AsType()[I0]); + .template SetAsType( + src_data_idx_seq, op_r_v.template AsType()[Helper::I0]); + // Gather-specific: skip gather dimension during coordinate movement auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; + auto move_on_dim_ = + Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths); - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; - }); - move_on_dim_(i) &= i.value != ordered_gather_dim; - }); + static_for<0, nDim, 1>{}( + [&](auto i) { move_on_dim_(i) &= i.value != ordered_gather_dim; }); return move_on_dim_; }(); + // move src coord static_for<0, nDim, 1>{}([&](auto i) { if(move_on_dim[i]) @@ -351,38 +287,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather // loop over tensor and copy static_ford{}([&](auto ordered_src_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, 
nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] - : ordered_src_access_lengths[i] - 1 - - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); + constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths, + forward_sweep, + src_dim_access_order, + src_scalar_per_access); constexpr auto src_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -501,7 +413,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather "wrong! SrcBuffer or DstBuffer data type is wrong"); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -512,66 +423,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_dst_access_lengths = container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - // make forward steps - const auto dst_forward_steps = generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, forward_step_idx); - }, - Number{}); - - // make backward steps - const auto dst_backward_steps = generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; - }); - - return make_tensor_coordinate_step(dst_desc, backward_step_idx); - }, - Number{}); + // make forward and backward steps + const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access); + const auto dst_backward_steps = + Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access); // loop over tensor and copy static_ford{}([&](auto ordered_dst_access_idx) { - // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] - : ordered_dst_access_lengths[i] - 1 - - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); + constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths, + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access); constexpr auto dst_data_idx_seq = generate_sequence_v2( [&](auto i) { return Number{}; }, Number{}); @@ -599,22 +465,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather dst_buf.template Set( dst_coord_.GetOffset() / PackedSize, is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = + Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -644,10 +498,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather } } + // Gather-specific: src coordinate reset zeroes the gather dimension __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); @@ -658,29 +511,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather constexpr auto ordered_src_access_lengths = container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; + constexpr auto ordered_access_lengths_minus_1 = generate_tuple( + [&](auto i) { return Number{}; }, Number{}); + constexpr auto forward_sweep = + Helper::ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_src_access_lengths); - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() constexpr auto src_data_idx = [&]() { - Index ordered_idx; + MultiIndex ordered_idx; static_for<0, nDim, 1>{}([&](auto i) { ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; @@ -690,9 +527,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather src_scalar_per_access; }(); - // + // Gather-specific: don't reset the gather dimension constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; + MultiIndex reset_src_data_step_; static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = i.value == GatherDim ? 
0 : -src_data_idx[i]; @@ -705,137 +542,32 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access; - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = - SrcResetCoordinateAfterRun ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + Helper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by RunWrite(), then need to adjust the step here - const auto adjusted_step_idx = - DstResetCoordinateAfterRun ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); - - // is it OK to construct a new step every time? 
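In the gather variant the generic serpentine stepping is reused, but the gather dimension is excluded from both the per-iteration coordinate moves and the end-of-sweep reset step, since that dimension is addressed through the gather offsets. A small sketch of that masking, assuming plain ints, std::array in place of StaticallyIndexedArray, and a hypothetical ordered_gather_dim value:

#include <array>
#include <cstdio>

constexpr int nDim = 3;
using Idx = std::array<int, nDim>;

// Generic serpentine rule: dimension i moves only when it is not at its last
// step and every faster-varying dimension (j > i) has just finished its sweep.
Idx move_on_dim(const Idx& idx, const Idx& lengths)
{
    Idx move{};
    for(int i = 0; i < nDim; ++i)
    {
        bool m = idx[i] < lengths[i] - 1;
        for(int j = i + 1; j < nDim; ++j)
            m = m && (idx[j] == lengths[j] - 1);
        move[i] = m ? 1 : 0;
    }
    return move;
}

// Gather-specific mask applied on top of the generic result: the gather
// dimension's coordinate is never stepped (and, analogously, never reset).
Idx mask_gather_dim(Idx move, int ordered_gather_dim)
{
    move[ordered_gather_dim] = 0;
    return move;
}

int main()
{
    const Idx lengths{2, 2, 4};
    const Idx idx{0, 1, 3};           // both inner dims at their last step
    const int ordered_gather_dim = 0; // hypothetical gather dimension
    Idx m = mask_gather_dim(move_on_dim(idx, lengths), ordered_gather_dim);
    std::printf("move on dim: %d %d %d\n", m[0], m[1], m[2]);
    // The generic rule would step dim 0 here; the mask suppresses it, mirroring
    // reset_src_data_step_(i) = (i == GatherDim) ? 0 : -src_data_idx[i] above.
}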
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + Helper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep); } __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() @@ -850,50 +582,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper:: + ComputeThreadScratchDescriptor(); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index 3c7291cca3..24fbd66be6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -7,10 +7,11 @@ #include "ck/tensor_description/tensor_descriptor.hpp" #include 
"ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor/static_tensor.hpp" #include "ck/utility/is_detected.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Assume: @@ -48,6 +49,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static constexpr index_t nSrc = SrcDescs::Size(); static constexpr index_t nDst = DstDescs::Size(); + using Helper = ThreadwiseTransferHelper_Serpentine; + // return a tuple of coordiantes for a tuple of tensor template {})); using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( const SrcDescs& src_descs, const StaticallyIndexedArray& src_slice_origins, @@ -101,7 +102,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2 Number thread_scratch_id = Number{}) { // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access constexpr auto src_scalar_per_access_tuple = generate_tuple( [&](auto src_i) { return generate_sequence( @@ -129,40 +129,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - // make forward steps + // make forward and backward steps const auto src_forward_steps_tuple = generate_tuple( [&](auto src_i) { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = - (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; - }); - - return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); - }, - Number{}); + return Helper::ComputeForwardSteps(src_descs.At(src_i), + src_scalar_per_access_tuple.At(src_i)); }, Number{}); - // make backward steps const auto src_backward_steps_tuple = generate_tuple( [&](auto src_i) { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? -src_scalar_per_access_tuple.At(src_i)[i] - : 0; - }); - - return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); - }, - Number{}); + return Helper::ComputeBackwardSteps(src_descs.At(src_i), + src_scalar_per_access_tuple.At(src_i)); }, Number{}); @@ -171,39 +149,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_ford>{}( [&](auto ordered_src_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths_tuple[j] + - ordered_src_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = Helper::ComputeForwardSweep( + ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i)); // calculate src data index - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? 
ordered_src_access_idx[i] - : ordered_src_access_lengths_tuple.At(src_i)[i] - - 1 - ordered_src_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access_tuple.At(src_i); - }(); + constexpr auto src_data_idx = + Helper::ComputeDataIndex(ordered_src_access_idx, + ordered_src_access_lengths_tuple.At(src_i), + forward_sweep, + src_dim_access_order, + src_scalar_per_access_tuple.At(src_i)); constexpr auto src_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -227,24 +182,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2 .At(src_i) .template SetAsType( src_data_idx_seq, - src_vector_container.template AsType()[I0]); + src_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_src_access_idx[i] < - ordered_src_access_lengths_tuple.At(src_i)[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_src_access_idx[j] == - ordered_src_access_lengths_tuple.At(src_i)[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = Helper::ComputeMoveOnDim( + ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i)); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -287,18 +228,30 @@ struct ThreadwiseTensorSliceTransfer_v3r2 { // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE // (it requires to add Elementwise support in transpose_vectors) - static_ford{}([&](auto idx) { - const auto src_data_refs = generate_tie( - [&](auto src_i) -> const auto& { - return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; - }, - Number{}); + if constexpr(nSrc == 1 && nDst == 1) + { + // Fast path: direct element transfer, no generate_tie/unpack2 overhead + static_ford{}([&](auto idx) { + element_op_(dst_thread_scratch_tuple_.At(Number<0>{})(idx), + src_thread_scratch_tuple_[thread_scratch_id].At(Number<0>{})[idx]); + }); + } + else + { + // General path: use generate_tie + unpack2 for multi-src/dst + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + [&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); - auto dst_data_refs = generate_tie( - [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, - Number{}); - unpack2(element_op_, dst_data_refs, src_data_refs); - }); + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } } template @@ -311,7 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2 TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); // src scalar per access on each dim - // TODO: don't use this constexpr auto dst_scalar_per_access_tuple = generate_tuple( [&](auto dst_i) { return generate_sequence( @@ -334,40 +286,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }, Number{}); - // make forward steps + // make forward and backward steps const auto dst_forward_steps_tuple = generate_tuple( [&](auto dst_i) { - return generate_tuple( - [&](auto i) { - Index forward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - forward_step_idx(j) = - (i.value == j.value) ? 
dst_scalar_per_access_tuple.At(dst_i)[i] : 0; - }); - - return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); - }, - Number{}); + return Helper::ComputeForwardSteps(dst_descs.At(dst_i), + dst_scalar_per_access_tuple.At(dst_i)); }, Number{}); - // make backward steps const auto dst_backward_steps_tuple = generate_tuple( [&](auto dst_i) { - return generate_tuple( - [&](auto i) { - Index backward_step_idx; - - static_for<0, nDim, 1>{}([&](auto j) { - backward_step_idx(j) = (i.value == j.value) - ? -dst_scalar_per_access_tuple.At(dst_i)[i] - : 0; - }); - - return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); - }, - Number{}); + return Helper::ComputeBackwardSteps(dst_descs.At(dst_i), + dst_scalar_per_access_tuple.At(dst_i)); }, Number{}); @@ -376,39 +306,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2 static_ford>{}( [&](auto ordered_dst_access_idx) { // judge move forward or move backward - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_idx[I0]; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + - ordered_dst_access_idx[j]; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); + constexpr auto forward_sweep = Helper::ComputeForwardSweep( + ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i)); // calculate dst data index - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] - ? ordered_dst_access_idx[i] - : ordered_dst_access_lengths_tuple.At(dst_i)[i] - - 1 - ordered_dst_access_idx[i]; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access_tuple.At(dst_i); - }(); + constexpr auto dst_data_idx = + Helper::ComputeDataIndex(ordered_dst_access_idx, + ordered_dst_access_lengths_tuple.At(dst_i), + forward_sweep, + dst_dim_access_order, + dst_scalar_per_access_tuple.At(dst_i)); constexpr auto dst_data_idx_seq = generate_sequence_v2([&](auto i) { return Number{}; }, @@ -434,24 +341,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2 dst_bufs.At(dst_i).template Update( dst_coords_.At(dst_i).GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[Helper::I0]); - constexpr auto move_on_dim = [&]() constexpr { - StaticallyIndexedArray move_on_dim_; - - static_for<0, nDim, 1>{}([&](auto i) { - move_on_dim_(i) = ordered_dst_access_idx[i] < - ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; - - static_for{}([&](auto j) { - move_on_dim_(i) &= - ordered_dst_access_idx[j] == - ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; - }); - }); - - return move_on_dim_; - }(); + constexpr auto move_on_dim = Helper::ComputeMoveOnDim( + ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i)); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { @@ -491,121 +384,19 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ static constexpr auto GetSrcCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_dim_access_order = SrcDimAccessOrder{}; - - constexpr auto 
ordered_src_access_lengths = - container_reorder_given_new2old(src_access_lengths, src_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_src_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate src data index after last iteration in RunRead(), if it has not being reset by - // RunRead() - constexpr auto src_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * - src_scalar_per_access; - }(); - - // - constexpr auto reset_src_data_step = [&]() { - Index reset_src_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); - - return reset_src_data_step_; - }(); - - return reset_src_data_step; + return Helper::ComputeCoordinateResetStep(); } template __device__ static constexpr auto GetDstCoordinateResetStep() { - // scalar per access on each dim - // TODO: don't use lambda_scalar_per_access - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_dim_access_order = DstDimAccessOrder{}; - - constexpr auto ordered_dst_access_lengths = - container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); - - // judge move forward or move backward during the last iteration - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto i) { - index_t tmp = ordered_dst_access_lengths[I0] - 1; - - static_for<1, i, 1>{}([&](auto j) { - tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; - }); - - forward_sweep_(i) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate dst data index after last iteration in RunWrite(), if it has not being reset by - // RunWrite() - constexpr auto dst_data_idx = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto i) { - ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; - }); - - return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * - dst_scalar_per_access.At(dst_i); - }(); - - // - constexpr auto reset_dst_data_step = [&]() { - Index reset_dst_data_step_; - - static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); - - return reset_dst_data_step_; - }(); - - return reset_dst_data_step; + return Helper::ComputeCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason @@ -649,103 +440,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2 template __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - constexpr auto src_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = - container_push_back(sequence_to_tuple_of_number(src_access_lengths), - Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper::ComputeThreadScratchDescriptor(); } template __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - constexpr auto dst_scalar_per_access = generate_sequence( - detail::lambda_scalar_per_access{}, - Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = - container_push_back(sequence_to_tuple_of_number(dst_access_lengths), - Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return Helper::ComputeThreadScratchDescriptor(); } __device__ static constexpr auto MakeSrcThreadScratchTuple() diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp index 74a964ddd8..45b638c842 100644 --- 
a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -42,8 +42,6 @@ struct ThreadwiseTensorSliceTransfer_v4r1 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index bce2d453dc..5d14d66eb3 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -44,9 +44,6 @@ struct ThreadwiseTensorSliceTransfer_v5r1 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); - using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); - __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, const Index& src_slice_origin, const DstDesc& dst_desc, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 2e255e2500..fc0ec9128d 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -40,11 +42,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, const Index& src_slice_origin, const DstDesc& dst_desc, @@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1 constexpr auto scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void 
MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = SrcResetCoordinateAfterRun - ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp index 43d4148dab..711f693f6f 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -39,11 +41,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2( const SrcDesc& src_desc, const Index& src_slice_origin, @@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2 constexpr auto scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& 
src_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = SrcResetCoordinateAfterRun - ? src_slice_origin_step_idx - : src_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); - - move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep); } // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp index f036bc4312..f7e5aa3adf 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -43,12 +45,12 @@ struct ThreadwiseTensorSliceTransfer_v6r2 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, const Index& src0_slice_origin, const Src1Desc& src1_desc, @@ -141,7 +143,7 @@ struct ThreadwiseTensorSliceTransfer_v6r2 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -187,67 +189,30 @@ struct ThreadwiseTensorSliceTransfer_v6r2 constexpr auto scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc0SliceWindow(const Src0Desc& 
src0_desc, const Index& src0_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src0ResetCoordinateAfterRun - ? src0_slice_origin_step_idx - : src0_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); - - move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& src1_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src1ResetCoordinateAfterRun - ? src1_slice_origin_step_idx - : src1_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); - - move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? 
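All of the v6r* MoveSrc*/MoveDstSliceWindow bodies collapse into the shared MoveSliceWindow helper, whose only real decision is whether to fold the end-of-sweep reset step into the window step so that a single coordinate move is issued. A minimal sketch of that fusion, assuming plain integer multi-indices in place of tensor coordinates and a hand-written reset step where the helper would use SpaceFillingCurve::GetStepBetween:

#include <array>
#include <cstdio>

constexpr int nDim = 2;
using Idx = std::array<int, nDim>;

Idx add(const Idx& a, const Idx& b)
{
    Idx r{};
    for(int i = 0; i < nDim; ++i)
        r[i] = a[i] + b[i];
    return r;
}

// If the coordinate was left at the end of the previous sweep
// (reset_after_run == false), fold the reset step (last access -> origin)
// into the window step; otherwise the window step is used as-is.
Idx adjusted_step(bool reset_after_run, const Idx& window_step, const Idx& reset_step)
{
    return reset_after_run ? window_step : add(window_step, reset_step);
}

int main()
{
    const Idx window_step{0, 8};  // advance the slice window by 8 along dim 1
    const Idx end_of_sweep{3, 7}; // where the space-filling curve finished
    const Idx reset_step{-end_of_sweep[0], -end_of_sweep[1]};
    const Idx s = adjusted_step(false, window_step, reset_step);
    std::printf("fused step: (%d, %d)\n", s[0], s[1]); // (-3, 1)
}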
- const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp index 7d53c1ac0d..79a6b5d3aa 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -8,6 +8,8 @@ #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + namespace ck { // Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory @@ -46,13 +48,13 @@ struct ThreadwiseTensorSliceTransfer_v6r3 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); - static constexpr auto I0 = Number<0>{}; - __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, const Index& src0_slice_origin, const Src1Desc& src1_desc, @@ -165,7 +167,7 @@ struct ThreadwiseTensorSliceTransfer_v6r3 dst_buf.template Update( dst_coord_.GetOffset(), is_dst_valid, - dst_vector_container.template AsType()[I0]); + dst_vector_container.template AsType()[SFCHelper::I0]); // move coordinate if constexpr(idx_1d.value != num_access - 1) @@ -221,82 +223,37 @@ struct ThreadwiseTensorSliceTransfer_v6r3 constexpr auto scalar_per_access = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); - using SpaceFillingCurve = SpaceFillingCurve>; - - constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); - if constexpr(num_access == 0) - { - return typename SpaceFillingCurve::Index{}; - } - else - { - constexpr auto reset_step = - SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - - return reset_step; - } + return SFCHelper::ComputeSFCCoordinateResetStep(); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& src0_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src0ResetCoordinateAfterRun - ? src0_slice_origin_step_idx - : src0_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); - - move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& src1_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src1ResetCoordinateAfterRun - ? 
src1_slice_origin_step_idx - : src1_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); - - move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep); } - // src_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& src2_slice_origin_step_idx) { - // if src coord was not reset by RunRead(), then need to adjust the step here - const auto adjusted_step_idx = Src2ResetCoordinateAfterRun - ? src2_slice_origin_step_idx - : src2_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); - - move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + src2_desc, src2_coord_, src2_slice_origin_step_idx, GetCoordinateResetStep); } - // dst_slice_origin_step_idx need to be known at compile-time, for performance reason __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& dst_slice_origin_step_idx) { - // if dst coord was not reset by Run(), then need to adjust the step here - const auto adjusted_step_idx = DstResetCoordinateAfterRun - ? dst_slice_origin_step_idx - : dst_slice_origin_step_idx + GetCoordinateResetStep(); - - // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); - - move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + SFCHelper::MoveSliceWindow( + dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep); } private: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp index 6326f6cbda..64f9ac2243 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -55,6 +55,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + // return a tuple of coordiantes for a tuple of tensor template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -473,98 +465,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return 
make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp index b4ee81697e..c4fad23f70 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -56,6 +56,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3 using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; + // return a tuple of coordiantes for a tuple of tensor template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -615,100 +607,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3 __device__ static constexpr auto GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - 
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 732922c157..45bd6f3f8e 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -63,6 +63,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static constexpr index_t nDst = DstDescs::Size(); using Index = MultiIndex; + using SFCHelper = ThreadwiseTransferHelper_SFC; static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); // return a tuple of coordiantes for a tuple of tensor @@ -134,17 +135,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter template __device__ static auto generate_vectors() { - auto data_types = DataTypes{}; - - constexpr index_t num = data_types.Size(); - - return generate_tuple( - [&](auto i) { - using DataType = remove_cvref_t; - - return vector_type_maker_t{}; - }, - Number{}); + return SFCHelper::MakeVectorContainerTuple(); } // SrcDescs: Tuple @@ -506,100 +497,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter __device__ static constexpr auto 
GetSrcThreadScratchDescriptor() { - // constexpr auto src_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; - - constexpr auto src_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(src_access_lengths), Number{}); - - // 1st stage of transforms - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(src_access_lengths_and_vector_length[i], - src_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(src_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == SrcVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } __device__ static constexpr auto GetDstThreadScratchDescriptor() { - // 1st stage of transforms - // constexpr auto dst_scalar_per_access = generate_sequence( - // detail::lambda_scalar_per_access{}, - // Number{}); - - constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; - - constexpr auto dst_access_lengths_and_vector_length = container_push_back( - sequence_to_tuple_of_number(dst_access_lengths), Number{}); - - constexpr auto desc0 = - make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); - - // 2nd stage of transforms - constexpr auto transforms = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return make_merge_transform_v3_division_mod( - make_tuple(dst_access_lengths_and_vector_length[i], - dst_access_lengths_and_vector_length[Number{}])); - } - else - { - return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); - } - }, - Number{}); - - constexpr auto low_dim_idss = generate_tuple( - [&](auto i) { - if constexpr(i == DstVectorDim) - { - return Sequence{}; - } - else - { - return Sequence{}; - } - }, - Number{}); - - constexpr auto up_dim_idss = generate_identity_sequences(); - - return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + return SFCHelper:: + ComputeThreadScratchDescriptor(); } // src_slice_origin_step_idx need to be known at compile-time, for performance reason diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b0b5f1c82f..017391549a 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -255,6 +255,7 @@ add_compile_options(-Wno-c++20-extensions) add_subdirectory(ck_tile) add_subdirectory(magic_number_division) add_subdirectory(space_filling_curve) +add_subdirectory(threadwise_transfer_helper) add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) diff --git a/test/threadwise_transfer_helper/CMakeLists.txt b/test/threadwise_transfer_helper/CMakeLists.txt new file mode 100644 index 0000000000..d157f19500 --- /dev/null +++ b/test/threadwise_transfer_helper/CMakeLists.txt @@ -0,0 +1,4 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+# SPDX-License-Identifier: MIT + +add_gtest_executable(test_threadwise_transfer_helper test_threadwise_transfer_helper.cpp) diff --git a/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp b/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp new file mode 100644 index 0000000000..0033fb0db8 --- /dev/null +++ b/test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp @@ -0,0 +1,748 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +using namespace ck; + +// ============================================================================= +// ThreadwiseTransferHelper_Base tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperBase, CompileTimeConstants) +{ + EXPECT_EQ(ThreadwiseTransferHelper_Base::I0.value, 0); + EXPECT_EQ(ThreadwiseTransferHelper_Base::I1.value, 1); + EXPECT_EQ(ThreadwiseTransferHelper_Base::I2.value, 2); + EXPECT_EQ(ThreadwiseTransferHelper_Base::I4.value, 4); + EXPECT_EQ(ThreadwiseTransferHelper_Base::I8.value, 8); + EXPECT_EQ(ThreadwiseTransferHelper_Base::I16.value, 16); +} + +TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySerpentine) +{ + // Serpentine inherits all constants from Base via public inheritance. + EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I0.value, 0); + EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I16.value, 16); +} + +TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySFC) +{ + // SFC inherits all constants from Base via public inheritance. + EXPECT_EQ(ThreadwiseTransferHelper_SFC::I0.value, 0); + EXPECT_EQ(ThreadwiseTransferHelper_SFC::I16.value, 16); +} + +// ============================================================================= +// ThreadwiseTransferHelper_Base::MoveSliceWindow tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetAlreadyDone) +{ + /* + * Scenario: v3r1's MoveSrcSliceWindow after RunRead has already reset + * the coordinate back to the slice origin (SrcResetCoordinateAfterRun=true). + * + * 2D packed tensor (4 rows x 8 columns), modelling a tile transfer: + * + * col: 0 1 2 3 4 5 6 7 + * row 0: [*] . . . . . . . <-- start at (0,0), offset=0 + * row 1: . . . . . . . . + * row 2: . . . . . . . . + * row 3: . . . . . . . . + * + * Step = (1, 0): move one row down. + * Reset step = (-3, 0): would move 3 rows up (irrelevant here). + * + * Since ResetCoordinateAfterRun=true, the reset step is NOT fused + * into the movement. The coordinate simply moves by the step alone. 
+ * + * Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8 + */ + using Helper = ThreadwiseTransferHelper_Base; + + constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{})); + + auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0)); + + EXPECT_EQ(coord.GetOffset(), 0); + + const auto step_idx = make_multi_index(1, 0); + + auto get_reset_step = []() { return make_multi_index(-3, 0); }; + + Helper::MoveSliceWindow( + desc, coord, step_idx, get_reset_step); + + // Coordinate moved by step only: (0,0) -> (1,0) + // Offset in row-major packed layout: 1*8 + 0 = 8 + EXPECT_EQ(coord.GetOffset(), 8); +} + +TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetFused) +{ + /* + * Scenario: v3r1's MoveSrcSliceWindow when RunRead did NOT reset + * the coordinate (SrcResetCoordinateAfterRun=false). This is the + * optimization path where MoveSliceWindow fuses the reset step + * with the movement step to save a separate coordinate adjustment. + * + * Same 2D packed tensor (4 rows x 8 columns): + * + * col: 0 1 2 3 4 5 6 7 + * row 0: [*] . . . . . . . <-- start at (0,0), offset=0 + * row 1: . . . . . . . . + * row 2: . . . . . . . . + * row 3: . . . . . . . . + * + * Step = (2, 0): move two rows down. + * Reset step = (-1, 0): move one row up (e.g., undo traversal overshoot). + * + * Since ResetCoordinateAfterRun=false, MoveSliceWindow adds the + * reset step to the movement step before applying: + * adjusted_step = step + reset = (2,0) + (-1,0) = (1,0) + * + * Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8 + */ + using Helper = ThreadwiseTransferHelper_Base; + + constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{})); + + auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0)); + + EXPECT_EQ(coord.GetOffset(), 0); + + const auto step_idx = make_multi_index(2, 0); + + auto get_reset_step = []() { return make_multi_index(-1, 0); }; + + Helper::MoveSliceWindow( + desc, coord, step_idx, get_reset_step); + + // adjusted_step = (2,0) + (-1,0) = (1,0) + // Offset: 1*8 + 0 = 8 + EXPECT_EQ(coord.GetOffset(), 8); +} + +TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_3D_ResetFused) +{ + /* + * Scenario: 3D packed tensor (2 x 4 x 8), modelling a typical GEMM + * intermediate buffer with SliceLengths = (batch, row, col). + * + * Layout (batch=0 shown, row-major packed): + * + * batch 0: + * col: 0 1 2 3 4 5 6 7 + * row 0: . . . . . . . . + * row 1: . . . . . . . . + * row 2: . . . . . . . . + * row 3: . . . . . . . . + * + * batch 1: (same structure, offset += 4*8 = 32) + * + * Start at (0, 0, 0), offset=0. + * + * Step = (0, 2, 0): move 2 rows down within the same batch. + * Reset step = (0, -1, 0): undo 1 row of traversal overshoot. 
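The two tests above pin down the fused path against hand-computed offsets; the same behaviour can also be checked directly against the unfused reset-then-move sequence it replaces. A sketch in the same style, assuming the ResetCoordinateAfterRun flag is supplied as an explicit template argument (as the tests above do):

TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_FusedMatchesResetThenMove)
{
    using Helper = ThreadwiseTransferHelper_Base;

    constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    const auto step_idx  = make_multi_index(2, 0);
    const auto reset_idx = make_multi_index(-1, 0);
    auto get_reset_step  = [&]() { return reset_idx; };

    // Fused path: ResetCoordinateAfterRun = false, so the helper folds the reset step in.
    auto coord_fused = make_tensor_coordinate(desc, make_multi_index(1, 0));
    Helper::MoveSliceWindow<false>(desc, coord_fused, step_idx, get_reset_step);

    // Unfused reference: apply the reset step, then the window step, separately.
    auto coord_ref = make_tensor_coordinate(desc, make_multi_index(1, 0));
    move_tensor_coordinate(desc, coord_ref, make_tensor_coordinate_step(desc, reset_idx));
    move_tensor_coordinate(desc, coord_ref, make_tensor_coordinate_step(desc, step_idx));

    EXPECT_EQ(coord_fused.GetOffset(), coord_ref.GetOffset()); // both at (2, 0) -> offset 16
}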
+ * + * ResetCoordinateAfterRun=false, so steps are fused: + * adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0) + * + * Expected: (0,0,0) + (0,1,0) = (0,1,0) + * Offset in packed layout: 0*(4*8) + 1*8 + 0 = 8 + */ + using Helper = ThreadwiseTransferHelper_Base; + + constexpr auto desc = + make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}, Number<8>{})); + + auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0, 0)); + + EXPECT_EQ(coord.GetOffset(), 0); + + const auto step_idx = make_multi_index(0, 2, 0); + + auto get_reset_step = []() { return make_multi_index(0, -1, 0); }; + + Helper::MoveSliceWindow( + desc, coord, step_idx, get_reset_step); + + // adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0) + // Offset: 0*32 + 1*8 + 0 = 8 + EXPECT_EQ(coord.GetOffset(), 8); +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeForwardSweep tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_EvenRow) +{ + /* + * 2D serpentine traversal on a 4x4 grid: + * + * dim1 -> + * 0 1 2 3 + * +-->-->-->--+ row 0: forward (dim0=0 is even) + * +--<--<--<--+ row 1: backward (dim0=1 is odd) + * +-->-->-->--+ row 2: forward (dim0=2 is even) + * +--<--<--<--+ row 3: backward (dim0=3 is odd) + * dim0 + * + * At position (0, *): dim0 is even -> dim1 sweeps FORWARD + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<0>{}, Number<0>{}); + constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + + EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward (outermost) + EXPECT_TRUE(sweep[Number<1>{}]); // dim 1: forward because dim0 position (0) is even +} + +TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_OddRow) +{ + /* + * Same 4x4 grid, but at row 1: + * + * +-->-->-->--+ row 0 + * +--<--<--<--+ row 1: dim0=1 is odd -> dim1 sweeps BACKWARD + * + * At position (1, *): dim0 is odd -> dim1 sweeps BACKWARD + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<1>{}, Number<0>{}); + constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + + EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward + EXPECT_FALSE(sweep[Number<1>{}]); // dim 1: backward (dim0 position 1 is odd) +} + +TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_1D) +{ + /* + * 1D traversal: always forward regardless of position. + * + * 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<3>{}); + constexpr auto lengths = make_tuple(Number<8>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + + EXPECT_TRUE(sweep[Number<0>{}]); // 1D: only dimension, always forward +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeMoveOnDim tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerNotComplete) +{ + /* + * 2D grid with ordered_access_lengths = (3, 2): + * + * dim1: 0 1 + * dim0: + * 0 [*] . <-- at (0,0): dim1 hasn't reached end yet + * 1 . . + * 2 . . 
+ * + * Rule: a dimension moves only when all faster-varying (higher-index) + * dimensions have completed their range. + * + * At (0, 0): + * dim0: dim1 is at 0, not at end (1). -> dim0 does NOT move. + * dim1: no higher dims to check, and 0 < 1. -> dim1 MOVES. + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<0>{}, Number<0>{}); + constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{}); + constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths); + + EXPECT_FALSE(move[Number<0>{}]); // dim 0: inner dim NOT at end + EXPECT_TRUE(move[Number<1>{}]); // dim 1: can advance +} + +TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerComplete) +{ + /* + * Same grid, at position (0, 1): + * + * dim1: 0 1 + * dim0: + * 0 . [*] <-- at (0,1): dim1 at its end (1 == 2-1) + * 1 . . + * 2 . . + * + * At (0, 1): + * dim0: dim1 is at end (1 == 1). dim0 < 2. -> dim0 MOVES. + * dim1: at end. -> dim1 does NOT move. + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<0>{}, Number<1>{}); + constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{}); + constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths); + + EXPECT_TRUE(move[Number<0>{}]); // dim 0: inner dim at end, can advance + EXPECT_FALSE(move[Number<1>{}]); // dim 1: at its limit, cannot advance +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeDataIndex tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeDataIndex_ForwardSweep) +{ + /* + * 2D grid (4x3), both dims sweeping forward, identity order, scale=1: + * + * ordered_access_idx = (2, 1) + * forward_sweep = (true, true) + * dim_access_order = (0, 1) <-- identity + * scalar_per_access = (1, 1) <-- no scaling + * + * Forward: data_idx = ordered_idx = (2, 1) + * Reorder: identity -> (2, 1) + * Scale: * (1,1) -> (2, 1) + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto idx = make_tuple(Number<2>{}, Number<1>{}); + constexpr auto lengths = make_tuple(Number<4>{}, Number<3>{}); + constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths); + constexpr auto order = Sequence<0, 1>{}; + constexpr auto spa = Sequence<1, 1>{}; + + constexpr auto data_idx = Helper::ComputeDataIndex(idx, lengths, sweep, order, spa); + + EXPECT_EQ(data_idx[Number<0>{}], 2); + EXPECT_EQ(data_idx[Number<1>{}], 1); +} + +// ============================================================================= +// ThreadwiseTransferHelper_Serpentine::ComputeCoordinateResetStep tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, ComputeCoordinateResetStep_2D) +{ + /* + * SliceLengths = (4, 2), VectorDim = 1, ScalarPerVector = 2 + * DimAccessOrder = (0, 1) + * + * scalar_per_access = (1, 2) [only dim 1 is vectorized with width 2] + * access_lengths = (4, 1) [4/1=4, 2/2=1] + * + * The traversal visits 4 positions along dim 0, each accessing 2 elements: + * + * dim0=0: access [0,0..1] + * dim0=1: access [1,0..1] (backward sweep, but only 1 step on dim1) + * dim0=2: access [2,0..1] + * dim0=3: access [3,0..1] + * + * Final position: data_idx = (3, 0) * scalar_per_access = (3, 0) + * Reset step: -(3, 0) = (-3, 0) + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + constexpr auto reset = + Helper::ComputeCoordinateResetStep, 1, 2, 
Sequence<0, 1>>(); + + EXPECT_EQ(reset[Number<0>{}], -3); + EXPECT_EQ(reset[Number<1>{}], 0); +} + +// ============================================================================= +// VectorSizeLookupTable / VectorOffsetsLookupTable tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSerpentine, VectorSizeLookupTable) +{ + /* + * Binary decomposition of vector widths into power-of-2 sub-loads: + * + * Width 0: (empty) -- no loads + * Width 1: {1} -- single 1-wide load + * Width 7: {4, 2, 1} -- 4+2+1 = 7 + * Width 8: {8} -- single 8-wide load + * Width 16: {16} -- single 16-wide load + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + + using VecSize0 = tuple_element_t<0, Helper::VectorSizeLookupTable>; + using VecSize1 = tuple_element_t<1, Helper::VectorSizeLookupTable>; + using VecSize7 = tuple_element_t<7, Helper::VectorSizeLookupTable>; + using VecSize8 = tuple_element_t<8, Helper::VectorSizeLookupTable>; + using VecSize16 = tuple_element_t<16, Helper::VectorSizeLookupTable>; + + EXPECT_EQ(VecSize0::Size(), 0); + + EXPECT_EQ(VecSize1::Size(), 1); + EXPECT_EQ(VecSize1::At(0), 1); + + EXPECT_EQ(VecSize7::Size(), 3); + EXPECT_EQ(VecSize7::At(0), 4); // first sub-load: 4 elements + EXPECT_EQ(VecSize7::At(1), 2); // second sub-load: 2 elements + EXPECT_EQ(VecSize7::At(2), 1); // third sub-load: 1 element + + EXPECT_EQ(VecSize8::Size(), 1); + EXPECT_EQ(VecSize8::At(0), 8); + + EXPECT_EQ(VecSize16::Size(), 1); + EXPECT_EQ(VecSize16::At(0), 16); +} + +TEST(ThreadwiseTransferHelperSerpentine, VectorOffsetsLookupTable) +{ + /* + * Starting element offsets for each sub-load in the decomposition: + * + * Width 7 = {4, 2, 1}: + * |<--- 4 --->|<- 2 ->|1| + * offset 0 offset 4 offset 6 + * + * So offsets = {0, 4, 6} + */ + using Helper = ThreadwiseTransferHelper_Serpentine; + using VecOff7 = tuple_element_t<7, Helper::VectorOffsetsLookupTable>; + + EXPECT_EQ(VecOff7::Size(), 3); + EXPECT_EQ(VecOff7::At(0), 0); // first sub-load starts at offset 0 + EXPECT_EQ(VecOff7::At(1), 4); // second sub-load starts at offset 4 + EXPECT_EQ(VecOff7::At(2), 6); // third sub-load starts at offset 6 +} + +// ============================================================================= +// ThreadwiseTransferHelper_SFC tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_SingleAccess) +{ + /* + * SliceLengths = (1, 1), ScalarPerAccess = (1, 1) + * Only 1 access position total -> already at origin, reset = (0, 0) + * + * [*] <-- single element, no movement needed + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 1>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<0, 1>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], 0); + EXPECT_EQ(reset[Number<1>{}], 0); +} + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_RowMajor) +{ + /* + * Typical v6r1 scenario: 2D slice transfer with vectorized column access. 
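The ComputeDataIndex coverage above only exercises a forward sweep. A companion sketch for a backward sweep, assuming the mirrored-index behaviour implied by the serpentine diagrams (a right-to-left row maps ordered position i to data position length - 1 - i); the expected values are inferred from those diagrams rather than taken from the implementation:

TEST(ThreadwiseTransferHelperSerpentine, ComputeDataIndex_BackwardSweep)
{
    /*
     * Same 4x3 grid as the forward-sweep case, but at ordered position (1, 0):
     * dim0 = 1 is odd, so dim1 sweeps backward and its data index is mirrored:
     *
     *   data_dim1 = (3 - 1) - 0 = 2
     *
     * Expected data index: (1, 2) with identity order and no scaling.
     */
    using Helper = ThreadwiseTransferHelper_Serpentine;

    constexpr auto idx     = make_tuple(Number<1>{}, Number<0>{});
    constexpr auto lengths = make_tuple(Number<4>{}, Number<3>{});
    constexpr auto sweep   = Helper::ComputeForwardSweep(idx, lengths);
    constexpr auto order   = Sequence<0, 1>{};
    constexpr auto spa     = Sequence<1, 1>{};

    constexpr auto data_idx = Helper::ComputeDataIndex(idx, lengths, sweep, order, spa);

    EXPECT_EQ(data_idx[Number<0>{}], 1);
    EXPECT_EQ(data_idx[Number<1>{}], 2);
}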
+ * + * SliceLengths = (4, 8) -- 4 rows, 8 columns + * DimAccessOrder = (0, 1) -- row-major traversal (rows change slowest) + * ScalarPerAccess = (1, 4) -- 4-wide vector loads along columns + * + * access_lengths = SliceLengths / ScalarPerAccess = (4, 2) + * + * The SFC traverses in serpentine order through 4*2 = 8 access positions: + * + * col: 0..3 4..7 + * row 0: [0]-->[1] access 0 -> idx (0,0), access 1 -> idx (0,4) + * row 1: [3]<--[2] access 2 -> idx (1,4), access 3 -> idx (1,0) + * row 2: [4]-->[5] access 4 -> idx (2,0), access 5 -> idx (2,4) + * row 3: [7]<--[6] access 6 -> idx (3,4), access 7 -> idx (3,0) + * + * Last access (#7) lands at index (3, 0). + * Reset step = origin - last = (0,0) - (3,0) = (-3, 0) + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 4>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<0, 1>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], -3); // return 3 rows up + EXPECT_EQ(reset[Number<1>{}], 0); // column already at origin +} + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_ColMajor) +{ + /* + * Same 2D slice but column-major traversal order. + * + * SliceLengths = (4, 8) -- 4 rows, 8 columns + * DimAccessOrder = (1, 0) -- column-major (columns change slowest) + * ScalarPerAccess = (1, 4) -- 4-wide vector loads along columns + * + * access_lengths = (4, 2) + * ordered_access_lengths = reorder_new2old((4,2), (1,0)) = (2, 4) + * (dim 1 is the "slow" outer dimension, dim 0 is the "fast" inner) + * + * Traversal (ordered dims are [col_block, row]): + * + * col_block: 0 1 + * row 0: [0] [7] + * row 1: [1] [6] + * row 2: [2] [5] + * row 3: [3] [4] + * + * Unordered indices (natural dim order): + * access 0 -> (row=0, col=0*4=0) + * access 3 -> (row=3, col=0) + * access 4 -> (row=3, col=1*4=4) (serpentine reversal in row) + * access 7 -> (row=0, col=4) + * + * Last access (#7) lands at index (0, 4). + * Reset step = (0,0) - (0,4) = (0, -4) + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 4>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<1, 0>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], 0); // row already at origin + EXPECT_EQ(reset[Number<1>{}], -4); // return 4 columns left +} + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_3D) +{ + /* + * 3D slice transfer, modelling a batch x row x col tile as used in + * batched GEMM or attention kernels (v7r2/v7r3). + * + * SliceLengths = (2, 4, 8) -- 2 batches, 4 rows, 8 columns + * DimAccessOrder = (0, 1, 2) -- batch outermost, column innermost + * ScalarPerAccess = (1, 1, 8) -- 8-wide vector loads on columns + * + * access_lengths = (2, 4, 1) + * Total accesses = 2 * 4 * 1 = 8 + * + * Traversal within each batch is serpentine on rows, columns scalar: + * + * batch 0: + * row 0: [0] -- (0, 0, 0) + * row 1: [1] -- (0, 1, 0) + * row 2: [2] -- (0, 2, 0) + * row 3: [3] -- (0, 3, 0) + * + * batch 1: (serpentine reversal on rows) + * row 3: [4] -- (1, 3, 0) + * row 2: [5] -- (1, 2, 0) + * row 1: [6] -- (1, 1, 0) + * row 0: [7] -- (1, 0, 0) + * + * Last access (#7) lands at index (1, 0, 0). 
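By construction, the helper's reset step should equal the SpaceFillingCurve step from the last access back to access 0, which is what the removed GetCoordinateResetStep() bodies computed. A cross-check sketch for the row-major (4, 8) configuration; the SpaceFillingCurve template arguments here are an assumption modelled on the removed v6r3 code:

TEST(ThreadwiseTransferHelperSFC, ResetStepMatchesSpaceFillingCurveStep)
{
    using SFCHelper = ThreadwiseTransferHelper_SFC;
    // Assumed alias: slice lengths, dim access order, scalars per access.
    using Curve = SpaceFillingCurve<Sequence<4, 8>, Sequence<0, 1>, Sequence<1, 4>>;

    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
                                                                    Sequence<0, 1>,
                                                                    Sequence<1, 4>>();

    constexpr auto num_access = Curve::GetNumOfAccess();
    constexpr auto expected   = Curve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});

    EXPECT_EQ(reset[Number<0>{}], expected[Number<0>{}]);
    EXPECT_EQ(reset[Number<1>{}], expected[Number<1>{}]);
}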
+ * Reset step = (0,0,0) - (1,0,0) = (-1, 0, 0) + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 1, 8>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<0, 1, 2>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], -1); // return 1 batch + EXPECT_EQ(reset[Number<1>{}], 0); // row already at origin (serpentine came back) + EXPECT_EQ(reset[Number<2>{}], 0); // column at origin (single access per row) +} + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_EvenInnerAccesses) +{ + /* + * When the number of accesses along the inner dimension is even, the + * serpentine traversal returns to the starting side on that dimension. + * + * SliceLengths = (4, 4) + * DimAccessOrder = (0, 1) + * ScalarPerAccess = (1, 2) -- 2-wide vector loads + * + * access_lengths = (4, 2) -- 2 accesses along cols (even) + * + * col: 0..1 2..3 + * row 0: [0]-->[1] access 0 -> (0,0), access 1 -> (0,2) + * row 1: [3]<--[2] access 2 -> (1,2), access 3 -> (1,0) + * row 2: [4]-->[5] access 4 -> (2,0), access 5 -> (2,2) + * row 3: [7]<--[6] access 6 -> (3,2), access 7 -> (3,0) + * + * Last access (#7) at (3, 0). Even number of column accesses (2) + * means the serpentine always returns to col=0 at the end of each row. + * Reset step = (0,0) - (3,0) = (-3, 0) + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 2>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<0, 1>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], -3); + EXPECT_EQ(reset[Number<1>{}], 0); // even inner accesses -> back at start column +} + +TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_OddInnerAccesses) +{ + /* + * When the number of accesses along the inner dimension is odd and the + * outer dimension is even, the serpentine returns to col=0. + * + * SliceLengths = (2, 6) + * DimAccessOrder = (0, 1) + * ScalarPerAccess = (1, 2) -- 2-wide vector loads + * + * access_lengths = (2, 3) -- 3 accesses along cols (odd!) + * + * col: 0..1 2..3 4..5 + * row 0: [0]-->[1]-->[2] access 0 -> (0,0), 1 -> (0,2), 2 -> (0,4) + * row 1: [5]<--[4]<--[3] access 3 -> (1,4), 4 -> (1,2), 5 -> (1,0) + * + * Last access (#5) at (1, 0). Even row count means serpentine reversal + * on the inner dim brings us back to col=0. 
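The same reset steps can also be read through the coordinate API: a coordinate sitting on the final access position, moved by the computed reset step, should land back on the slice origin. A sketch for the row-major (4, 8) / (1, 4) configuration, assuming the reset step is consumed by make_tensor_coordinate_step like any other step index (which is how the MoveSliceWindow paths use it):

TEST(ThreadwiseTransferHelperSFC, ResetStepReturnsCoordinateToOrigin)
{
    using SFCHelper = ThreadwiseTransferHelper_SFC;

    constexpr auto desc =
        make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    // Final access of the (4, 8) row-major traversal lands on (3, 0).
    auto coord = make_tensor_coordinate(desc, make_multi_index(3, 0));
    EXPECT_EQ(coord.GetOffset(), 24); // 3 * 8 + 0

    constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
                                                                    Sequence<0, 1>,
                                                                    Sequence<1, 4>>();

    move_tensor_coordinate(desc, coord, make_tensor_coordinate_step(desc, reset));

    EXPECT_EQ(coord.GetOffset(), 0); // back at the slice origin
}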
+ * Reset step = (0,0) - (1,0) = (-1, 0) + */ + using SFCHelper = ThreadwiseTransferHelper_SFC; + + constexpr auto scalar_per_access = Sequence<1, 2>{}; + constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep, + Sequence<0, 1>, + decltype(scalar_per_access)>(); + + EXPECT_EQ(reset[Number<0>{}], -1); // return 1 row + EXPECT_EQ(reset[Number<1>{}], 0); // even outer accesses -> serpentine came back to col=0 +} + +// ============================================================================= +// Inheritance structure tests +// ============================================================================= + +TEST(ThreadwiseTransferHelperInheritance, SerpentineIsDerivedFromBase) +{ + /* + * ThreadwiseTransferHelper_Base + * | + * +-- ThreadwiseTransferHelper_Serpentine <-- this relationship + * | + * +-- ThreadwiseTransferHelper_SFC + */ + static_assert( + std::is_base_of_v); +} + +TEST(ThreadwiseTransferHelperInheritance, SFCIsDerivedFromBase) +{ + /* + * ThreadwiseTransferHelper_Base + * | + * +-- ThreadwiseTransferHelper_Serpentine + * | + * +-- ThreadwiseTransferHelper_SFC <-- this relationship + */ + static_assert(std::is_base_of_v); +} + +TEST(ThreadwiseTransferHelperInheritance, SerpentineAndSFCAreNotRelated) +{ + /* + * Serpentine and SFC are siblings -- neither inherits from the other. + * + * ThreadwiseTransferHelper_Base + * | + * +-- Serpentine (NOT parent of SFC) + * | + * +-- SFC (NOT parent of Serpentine) + */ + static_assert( + !std::is_base_of_v); + static_assert( + !std::is_base_of_v); +} + +// ============================================================================= +// detail:: functor tests +// ============================================================================= + +TEST(DetailFunctors, LambdaScalarPerAccess) +{ + /* + * For VectorDim=1 and ScalarPerVector=8: + * + * dim: 0 1 2 + * result: 1 8 1 + * ^ ^ ^ + * | | +-- not the vector dim + * | +------ THE vector dim (returns ScalarPerVector) + * +---------- not the vector dim + */ + constexpr auto f = detail::lambda_scalar_per_access<1, 8>{}; + + EXPECT_EQ(f(0), 1); + EXPECT_EQ(f(1), 8); + EXPECT_EQ(f(2), 1); +} + +TEST(DetailFunctors, LambdaScalarStepInVector) +{ + /* + * For VectorDim=2: + * + * dim: 0 1 2 3 + * result: 0 0 1 0 + * ^ + * +-- THE vector dim (step = 1) + */ + constexpr auto f = detail::lambda_scalar_step_in_vector<2>{}; + + EXPECT_EQ(f(0), 0); + EXPECT_EQ(f(1), 0); + EXPECT_EQ(f(2), 1); + EXPECT_EQ(f(3), 0); +} + +TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_SameDim) +{ + /* + * Src and Dst both vectorize dim 1: + * SrcVectorDim=1, SrcScalarPerVector=4 + * DstVectorDim=1, DstScalarPerVector=8 + * + * dim: 0 1 2 + * result: 1 lcm(4,8) 1 + * = 8 + */ + constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<1, 4, 1, 8>{}; + + EXPECT_EQ(f(0), 1); + EXPECT_EQ(f(1), 8); // lcm(4, 8) = 8 + EXPECT_EQ(f(2), 1); +} + +TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_DifferentDims) +{ + /* + * Src vectorizes dim 0 (width 4), Dst vectorizes dim 2 (width 8): + * + * dim: 0 1 2 + * result: 4(src) 1 8(dst) + */ + constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<0, 4, 2, 8>{}; + + EXPECT_EQ(f(0), 4); // src vector dim + EXPECT_EQ(f(1), 1); // neither + EXPECT_EQ(f(2), 8); // dst vector dim +}
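In the transfer classes these functors are not used in isolation; they are passed to generate_sequence to build the per-dimension access widths that the SFC helpers consume. A closing sketch of that wiring, reusing the row-major configuration from the SFC reset-step tests, so the expected step is the same (-3, 0):

TEST(DetailFunctors, LambdaScalarPerAccessFeedsSFCHelper)
{
    // VectorDim = 1, ScalarPerVector = 4, nDim = 2: generates Sequence<1, 4>,
    // the same access widths the SFC reset-step tests above spell out by hand.
    constexpr auto scalar_per_access =
        generate_sequence(detail::lambda_scalar_per_access<1, 4>{}, Number<2>{});

    constexpr auto reset =
        ThreadwiseTransferHelper_SFC::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
                                                                    Sequence<0, 1>,
                                                                    decltype(scalar_per_access)>();

    EXPECT_EQ(reset[Number<0>{}], -3);
    EXPECT_EQ(reset[Number<1>{}], 0);
}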