[rocm-libraries] ROCm/rocm-libraries#4673 (commit ec385da)

Compile-time optimize threadwise slice transfer

## Motivation

Profiling with `-ftime-trace` on representative translation units (e.g., `device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp`) revealed that **92% of frontend time was spent in template instantiation**. The primary bottleneck was redundant instantiation of identical helper logic across multiple threadwise transfer class variants.

Each `ThreadwiseTensorSliceTransfer_v*` class independently contained its own copy of the same helper computations (serpentine traversal, coordinate stepping, thread scratch descriptors, lambda-like functors, and compile-time constants), duplicated across 13 header files. When a typical GEMM or convolution kernel TU includes blockwise operations (e.g., `blockwise_gemm_xdlops.hpp`), it pulls in multiple transfer variants simultaneously, causing the compiler to instantiate the same helper logic multiple times with the same template arguments.

This was compounded by the helpers being defined as members of the outer `ThreadwiseTensorSliceTransfer_v*` classes, which carry 14+ template parameters. Functions like `ComputeForwardSweep` depend only on their two argument types, but because they were inline members of the outer class, the compiler had to create separate instantiations for every unique combination of all outer parameters (data types, descriptors, vector widths, etc.), even when most of those parameters had no effect on the helper's output, as the sketch below illustrates.
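
As a stripped-down illustration of the problem (hypothetical names, not CK code): when the helper is a member of a heavily templated class, every distinct set of outer template arguments stamps out its own copy of the helper, even though the helper's body depends on only one of those parameters.

```cpp
#include <cstdio>

// Hypothetical stand-in for a transfer class with many template parameters.
template <typename SrcData, typename DstData, int VectorDim, int ScalarPerVector>
struct TransferBefore
{
    // Member helper: one instantiation per (SrcData, DstData, VectorDim, ScalarPerVector)
    // combination, even though the body uses only VectorDim.
    static constexpr int HelperUsingOnlyVectorDim() { return 2 * VectorDim; }
};

int main()
{
    // Three distinct outer instantiations -> three separate copies of identical helper logic.
    std::printf("%d %d %d\n",
                TransferBefore<float, float, 1, 8>::HelperUsingOnlyVectorDim(),
                TransferBefore<double, double, 1, 8>::HelperUsingOnlyVectorDim(),
                TransferBefore<int, int, 1, 4>::HelperUsingOnlyVectorDim());
    return 0;
}
```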

## Technical Details

### The Fix: Shared Helper Struct Hierarchy

Duplicated logic was extracted into a standalone helper hierarchy in
`threadwise_tensor_slice_transfer_util.hpp`:

```
ThreadwiseTransferHelper_Base          (I0..I16, MoveSliceWindow, ComputeThreadScratchDescriptor,
|                                       ComputeForwardSteps, ComputeBackwardSteps, MakeVectorContainerTuple)
+-- ThreadwiseTransferHelper_Serpentine (ComputeForwardSweep, ComputeMoveOnDim, ComputeDataIndex,
|                                       ComputeCoordinateResetStep, VectorSizeLookupTable, VectorOffsetsLookupTable)
+-- ThreadwiseTransferHelper_SFC       (ComputeSFCCoordinateResetStep)
```

Each helper method is now parameterized **only by what it actually uses**:

- `ComputeForwardSweep(idx, lengths)`: parameterized only by the two argument types, not by `SrcData`, `DstData`, `SrcDesc`, etc.
- `ComputeForwardSteps(desc, scalar_per_access)`: parameterized only by the descriptor and access sequence types.
- `ComputeCoordinateResetStep<SliceLengths, VectorDim, ScalarPerVector, DimAccessOrder>()`: parameterized only by the four values it actually needs.

This reduces template instantiation work in two ways:
1. **Across different transfer variants** (v3r1 vs v3r2 vs v3r1_gather): the compiler reuses a single instantiation instead of creating one per variant.
2. **Across different outer class instantiations** (fp16 vs bf16 vs int8): the compiler reuses the helper instantiation because the helper doesn't depend on the data type at all (see the sketch after this list).
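
Continuing the same schematic example (again hypothetical, not the CK helper itself), hoisting the helper into a standalone struct keyed only on the parameter it uses lets every outer variant share one instantiation:

```cpp
#include <cstdio>

// Hypothetical standalone helper, parameterized only by what it actually uses.
struct HelperAfter
{
    template <int VectorDim>
    static constexpr int UsingOnlyVectorDim() { return 2 * VectorDim; }
};

// Outer classes with different data types all delegate to the shared helper.
template <typename SrcData, typename DstData, int VectorDim, int ScalarPerVector>
struct TransferAfter
{
    static constexpr int Compute() { return HelperAfter::UsingOnlyVectorDim<VectorDim>(); }
};

int main()
{
    // All three outer variants reuse the single HelperAfter::UsingOnlyVectorDim<1>
    // instantiation, regardless of data type or ScalarPerVector.
    std::printf("%d %d %d\n",
                TransferAfter<float, float, 1, 8>::Compute(),
                TransferAfter<double, double, 1, 8>::Compute(),
                TransferAfter<int, int, 1, 4>::Compute());
    return 0;
}
```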

### Refactored Headers

**13 headers** now delegate to the shared helpers instead of duplicating logic:
- Serpentine family: v3r1, v3r2, v3r1_gather, v3r1_dequant
- SFC family: v6r1, v6r1r2, v6r2, v6r3, v7r2, v7r3, v7r3_scatter
- Dead code removed: v4r1, v5r1

### Additional Fixes Found During Refactoring

- Two latent bugs in v3r2 (`forward_sweep` indexing, `GetDstCoordinateResetStep` extraction)
- Dead `SrcCoordStep` variables in v4r1 and v5r1
- Unused `scale_element_op_` member in v3r1_dequant (restored with note)

### Net Code Change

+1,428 / -2,297 lines (~870 lines removed).

## Test Plan

### Unit Tests

28 host-side gtests in `test/threadwise_transfer_helper/test_threadwise_transfer_helper.cpp` covering the full helper hierarchy:

| Suite | Tests | What is verified |
|-------|-------|------------------|
| ThreadwiseTransferHelperBase | 6 | Compile-time constants, inheritance, `MoveSliceWindow` with `ResetCoordinateAfterRun` true/false in 2D and 3D |
| ThreadwiseTransferHelperSerpentine | 9 | `ComputeForwardSweep` (even/odd row, 1D), `ComputeMoveOnDim` (inner complete/incomplete), `ComputeDataIndex`, `ComputeCoordinateResetStep`, `VectorSizeLookupTable`, `VectorOffsetsLookupTable` |
| ThreadwiseTransferHelperSFC | 6 | `ComputeSFCCoordinateResetStep`: single access, 2D row-major, 2D column-major, 3D batch, even/odd inner access counts |
| ThreadwiseTransferHelperInheritance | 3 | Serpentine and SFC derive from Base and are not related to each other |
| DetailFunctors | 4 | `lambda_scalar_per_access`, `lambda_scalar_step_in_vector`, `lambda_scalar_per_access_for_src_and_dst` (same dim, different dims) |

### Semantic Equivalence

GPU ISA comparison using `--cuda-device-only -S` confirmed identical assembly output (modulo `__hip_cuid_*` metadata) between baseline and refactored code.

## Test Results

All measurements on a 384-core machine, `-j64`, freshly rebooted,
near-idle.

### Targeted Builds (affected targets only)

| Target | Baseline | Refactored | Wall-clock Delta | CPU Delta |
|--------|----------|------------|------------------|-----------|
| `device_grouped_conv2d_fwd_instance` (160 TUs) | 7m 37s / 189m CPU | 6m 53s / 161m CPU | **-9.7%** | **-14.9%** |
| `device_grouped_conv3d_fwd_instance` (185 TUs) | 9m 49s / 202m CPU | 6m 42s / 182m CPU | **-31.8%** | **-10.0%** |
| **Combined** | **17m 27s / 392m CPU** | **13m 35s / 344m CPU** | **-22.2%** | **-12.4%** |

### Full Project Build (8,243 targets)

| Metric | Baseline | Refactored | Delta |
|--------|----------|------------|-------|
| Wall-clock | 103m 38s | 111m 56s | +8.0%* |
| CPU time | 4705m 7s | 4648m 17s | **-1.2%** |

\*Wall-clock inflated by external load spike during refactored build
(load 90 vs 66). CPU time is the reliable metric.

### Context

~15% of all build targets (1,262 / 8,243) transitively include the modified headers. These are primarily GEMM and convolution kernel instantiations, the core compute workloads. The 12-15% CPU savings on affected targets is diluted to 1.2% across the full project because 85% of targets are unaffected.

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

Christopher Millette, 2026-03-06 16:27:59 +00:00, committed by assistant-librarian[bot]
Parent b80e41f3bc, commit e2ce0cad54
17 changed files with 1747 additions and 2282 deletions

View File

@@ -1,19 +1,28 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
/**
* @file threadwise_tensor_slice_transfer_util.hpp
* @brief Shared helper class hierarchy for threadwise tensor slice transfer variants.
*
* Provides a three-tier inheritance structure:
*
* - @ref ThreadwiseTransferHelper_Base -- generic coordinate/descriptor utilities
* - @ref ThreadwiseTransferHelper_Serpentine -- serpentine (snake/zigzag) traversal
* - @ref ThreadwiseTransferHelper_SFC -- SpaceFillingCurve traversal
*/
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
/** @brief Functor returning ScalarPerVector for dimension VectorDim, 1 otherwise. */
template <index_t VectorDim, index_t ScalarPerVector>
struct lambda_scalar_per_access
{
@@ -23,6 +32,7 @@ struct lambda_scalar_per_access
}
};
/** @brief Functor returning 1 for dimension VectorDim, 0 otherwise. */
template <index_t VectorDim>
struct lambda_scalar_step_in_vector
{
@@ -32,8 +42,10 @@ struct lambda_scalar_step_in_vector
}
};
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
/**
* @brief Functor computing scalar-per-access for combined src/dst vector dimensions.
* Returns lcm when both src and dst share the same vector dimension.
*/
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
@@ -61,18 +73,449 @@ struct lambda_scalar_per_access_for_src_and_dst
}
};
template <index_t WaveNum, index_t nDim>
struct lambda_wave_cluster_dimension
} // namespace detail
/**
* @brief Base helper with methods shared by all threadwise transfer variants.
*
* Both ThreadwiseTransferHelper_Serpentine and ThreadwiseTransferHelper_SFC
* inherit from this class. Contains generic coordinate stepping, thread scratch
* descriptor construction, and compile-time index constants.
*/
struct ThreadwiseTransferHelper_Base
{
__host__ __device__ constexpr auto operator()(index_t i) const
/**
* @name Compile-time index constants
* @{
*/
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
/** @} */
/**
* @brief Move the slice window by a step, optionally fusing coordinate reset.
*
* If the coordinate was not reset after RunRead/RunWrite, the reset step is
* added to the movement step to avoid a separate coordinate adjustment.
*
* @tparam ResetCoordinateAfterRun Whether the coordinate was already reset.
* @param desc Tensor descriptor.
* @param coord Tensor coordinate to move (modified in place).
* @param slice_origin_step_idx Step index for the slice window movement.
* @param get_reset_step Callable returning the coordinate reset step.
*/
template <typename Desc,
typename Coord,
bool ResetCoordinateAfterRun,
typename StepIdx,
typename GetCoordinateResetStepFunc>
__host__ __device__ static void MoveSliceWindow(const Desc& desc,
Coord& coord,
const StepIdx& slice_origin_step_idx,
GetCoordinateResetStepFunc get_reset_step)
{
if((nDim - i) == 3)
return WaveNum;
else
return 1;
const auto adjusted_step_idx = ResetCoordinateAfterRun
? slice_origin_step_idx
: slice_origin_step_idx + get_reset_step();
const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx);
move_tensor_coordinate(desc, coord, adjusted_step);
}
/**
* @brief Build the thread-local scratch tensor descriptor.
*
* Creates a transformed tensor descriptor where the vector dimension is merged
* with an additional dimension of size ScalarPerVector, enabling vector-typed
* access to the scratch buffer.
*
* @tparam SliceLengths Compile-time sequence of per-dimension slice lengths.
* @tparam VectorDim Which dimension is vectorized.
* @tparam ScalarPerVector_ Number of scalars per vector load/store.
* @return Transformed tensor descriptor for the thread scratch buffer.
*/
template <typename SliceLengths, index_t VectorDim, index_t ScalarPerVector_>
__host__ __device__ static constexpr auto ComputeThreadScratchDescriptor()
{
constexpr index_t nDim = SliceLengths::Size();
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector_>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
constexpr auto access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(access_lengths), Number<ScalarPerVector_>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(access_lengths_and_vector_length);
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == VectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(access_lengths_and_vector_length[i],
access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == VectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
/**
* @brief Compute forward (+1) coordinate steps for each dimension.
*
* Returns a tuple of nDim coordinate steps, where step[i] moves by
* +scalar_per_access[i] in dimension i and 0 in all other dimensions.
*
* @param desc Tensor descriptor.
* @param scalar_per_access Per-dimension access widths (Sequence type).
*/
template <typename Desc, typename ScalarPerAccess>
__host__ __device__ static constexpr auto
ComputeForwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access)
{
constexpr index_t nDim = ScalarPerAccess::Size();
return generate_tuple(
[&](auto i) {
MultiIndex<nDim> step_idx;
static_for<0, nDim, 1>{}(
[&](auto j) { step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0; });
return make_tensor_coordinate_step(desc, step_idx);
},
Number<nDim>{});
}
/**
* @brief Compute backward (-1) coordinate steps for each dimension.
*
* Same as ComputeForwardSteps but with negated step values.
*
* @param desc Tensor descriptor.
* @param scalar_per_access Per-dimension access widths (Sequence type).
*/
template <typename Desc, typename ScalarPerAccess>
__host__ __device__ static constexpr auto
ComputeBackwardSteps(const Desc& desc, const ScalarPerAccess& scalar_per_access)
{
constexpr index_t nDim = ScalarPerAccess::Size();
return generate_tuple(
[&](auto i) {
MultiIndex<nDim> step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, step_idx);
},
Number<nDim>{});
}
/**
* @brief Create a tuple of default-constructed vector containers, one per data type.
*
* @tparam DataTypes Tuple of data types (e.g., Tuple<fp16_t, fp16_t>).
* @tparam ScalarPerVector Number of scalars per vector.
* @return Tuple of vector_type_maker_t<DataType, ScalarPerVector> instances.
*/
template <typename DataTypes, index_t ScalarPerVector>
__host__ __device__ static auto MakeVectorContainerTuple()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
}
};
} // namespace detail
/**
* @brief Serpentine (snake/zigzag) traversal helper.
*
* Provides methods for computing serpentine sweep directions, dimension movement
* decisions, and coordinate reset steps used by the v3r1 family of transfer classes.
*
* Used by: ThreadwiseTensorSliceTransfer_v3r1, v3r2, v3r1_gather, v3r1_dequant.
*/
struct ThreadwiseTransferHelper_Serpentine : ThreadwiseTransferHelper_Base
{
/**
* @brief Binary decomposition of vector widths 0-16 into power-of-2 sub-load sizes.
* Index N gives the sequence of sub-load widths whose sum equals N.
* E.g. index 7 -> Sequence<I4, I2, I1> means loads of width 4, 2, 1.
*/
using VectorSizeLookupTable = Tuple<Sequence<>,
Sequence<I1>,
Sequence<I2>,
Sequence<I2, I1>,
Sequence<I4>,
Sequence<I4, I1>,
Sequence<I4, I2>,
Sequence<I4, I2, I1>,
Sequence<I8>,
Sequence<I8, I1>,
Sequence<I8, I2>,
Sequence<I8, I2, I1>,
Sequence<I8, I4>,
Sequence<I8, I4, I1>,
Sequence<I8, I4, I2>,
Sequence<I8, I4, I2, I1>,
Sequence<I16>>;
/**
* @brief Starting offsets for each sub-load in VectorSizeLookupTable.
* E.g. index 7 -> Sequence<I0, I4, I6> means offsets 0, 4, 6.
*/
using VectorOffsetsLookupTable = Tuple<Sequence<>,
Sequence<I0>,
Sequence<I0>,
Sequence<I0, I2>,
Sequence<I0>,
Sequence<I0, I4>,
Sequence<I0, I4>,
Sequence<I0, I4, I6>,
Sequence<I0>,
Sequence<I0, I8>,
Sequence<I0, I8>,
Sequence<I0, I8, I10>,
Sequence<I0, I8>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12, I14>,
Sequence<I0>>;
/**
* @brief Compute serpentine sweep direction for each dimension.
*
* Determines whether each dimension should be traversed forward or backward
* based on the current position in the ordered access grid, implementing
* a zigzag (serpentine) traversal pattern.
*
* @param ordered_access_idx Current position in the ordered access grid.
* @param ordered_access_lengths Size of the ordered access grid per dimension.
* @return Array of booleans: true = forward, false = backward.
*/
template <typename OrderedAccessIdx, typename OrderedAccessLengths>
__host__ __device__ static constexpr auto
ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths)
{
constexpr index_t nDim = OrderedAccessLengths::Size();
static_assert(OrderedAccessIdx::Size() == nDim,
"ordered_access_idx and ordered_access_lengths must have same nDim");
StaticallyIndexedArray_v2<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}
/**
* @brief Determine which dimensions need coordinate movement at a given iteration.
*
* A dimension moves when it hasn't reached its end and all higher-priority
* (faster-varying) dimensions have completed their ranges.
*
* @param ordered_access_idx Current position in the ordered access grid.
* @param ordered_access_lengths Size of the ordered access grid per dimension.
* @return Array of booleans: true = move coordinate on this dimension.
*/
template <typename OrderedAccessIdx, typename OrderedAccessLengths>
__host__ __device__ static constexpr auto
ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths)
{
constexpr index_t nDim = OrderedAccessLengths::Size();
static_assert(OrderedAccessIdx::Size() == nDim,
"ordered_access_idx and ordered_access_lengths must have same nDim");
StaticallyIndexedArray_v2<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
/**
* @brief Convert ordered access index to natural dimension order and apply scaling.
*
* @param ordered_access_idx Current position in the ordered access grid.
* @param ordered_access_lengths Size of the ordered access grid per dimension.
* @param forward_sweep Per-dimension sweep direction.
* @param dim_access_order Mapping from ordered to natural dimension indices.
* @param scalar_per_access Per-dimension access widths.
* @return MultiIndex in natural dimension order, scaled by scalar_per_access.
*/
template <typename OrderedAccessIdx,
typename OrderedAccessLengths,
typename ForwardSweep,
typename DimAccessOrder,
typename ScalarPerAccess>
__host__ __device__ static constexpr auto
ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths,
const ForwardSweep& forward_sweep,
const DimAccessOrder& dim_access_order,
const ScalarPerAccess& scalar_per_access)
{
constexpr index_t nDim = ScalarPerAccess::Size();
static_assert(OrderedAccessIdx::Size() == nDim,
"all arguments to ComputeDataIndex must have same nDim");
MultiIndex<nDim> ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access;
}
/**
* @brief Compute the coordinate step needed to return to the origin after traversal.
*
* Determines where the coordinate ends up after a full serpentine traversal,
* then returns the negated position as the reset step.
*
* @tparam SliceLengths Compile-time sequence of per-dimension slice lengths.
* @tparam VectorDim Which dimension is vectorized.
* @tparam ScalarPerVector_ Number of scalars per vector load/store.
* @tparam DimAccessOrder Compile-time sequence mapping ordered to natural dims.
* @return MultiIndex representing the step to reset the coordinate to the origin.
*/
template <typename SliceLengths,
index_t VectorDim,
index_t ScalarPerVector_,
typename DimAccessOrder>
__host__ __device__ static constexpr auto ComputeCoordinateResetStep()
{
constexpr index_t nDim = SliceLengths::Size();
static_assert(DimAccessOrder::Size() == nDim,
"SliceLengths and DimAccessOrder must have same nDim");
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector_>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
constexpr auto ordered_access_lengths_minus_1 = generate_tuple(
[&](auto i) { return Number<ordered_access_lengths.At(i) - 1>{}; }, Number<nDim>{});
constexpr auto forward_sweep =
ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths);
constexpr auto reset_step = [&]() {
MultiIndex<nDim> ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
auto data_idx =
container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access;
MultiIndex<nDim> step;
static_for<0, nDim, 1>{}([&](auto i) { step(i) = -data_idx[i]; });
return step;
}();
return reset_step;
}
};
/**
* @brief SpaceFillingCurve traversal helper.
*
* Provides coordinate reset computation using SpaceFillingCurve's GetStepBetween
* method, which computes the step from the last access position back to the origin.
*
* Used by: ThreadwiseTensorSliceTransfer v6r1, v6r1r2, v6r2, v6r3, v7r2, v7r3,
* v7r3_scatter.
*/
struct ThreadwiseTransferHelper_SFC : ThreadwiseTransferHelper_Base
{
/**
* @brief Compute the coordinate reset step using SpaceFillingCurve traversal.
*
* @tparam SliceLengths Compile-time sequence of per-dimension slice lengths.
* @tparam DimAccessOrder Compile-time sequence defining dimension access order.
* @tparam ScalarPerAccess Compile-time sequence of per-dimension access widths.
* @return MultiIndex representing the step from last access position to origin.
*/
template <typename SliceLengths, typename DimAccessOrder, typename ScalarPerAccess>
__host__ __device__ static constexpr auto ComputeSFCCoordinateResetStep()
{
using SFC = SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<ScalarPerAccess>>;
constexpr auto num_access = SFC::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SFC::Index{};
}
else
{
return SFC::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
}
}
};
} // namespace ck

View File

@@ -44,263 +44,63 @@ template <typename SliceLengths,
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v3r1
{
// =====================================================================
// Private type aliases and constants
// =====================================================================
private:
using Helper = ThreadwiseTransferHelper_Serpentine;
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t PackedSize = is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ? 2 : 1;
static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
// =====================================================================
// Private implementation methods (must be declared before public methods
// that call them)
// =====================================================================
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
return Helper::ComputeCoordinateResetStep<SliceLengths,
SrcVectorDim,
SrcScalarPerVector_,
SrcDimAccessOrder>();
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
__device__ static constexpr auto GetDstCoordinateResetStep()
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
return Helper::ComputeCoordinateResetStep<SliceLengths,
DstVectorDim,
DstScalarPerVector_,
DstDimAccessOrder>();
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector_>();
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward and backward steps
const auto src_forward_steps = ComputeForwardSteps(src_desc, src_scalar_per_access);
const auto src_backward_steps = ComputeBackwardSteps(src_desc, src_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
// calculate src data index
constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// maintain a container record is_src_valid, waiting for RunWrite use.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
else
{
return 1;
}
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
using VectorSizeLookupTable = Tuple<Sequence<>,
Sequence<I1>,
Sequence<I2>,
Sequence<I2, I1>,
Sequence<I4>,
Sequence<I4, I1>,
Sequence<I4, I2>,
Sequence<I4, I2, I1>,
Sequence<I8>,
Sequence<I8, I1>,
Sequence<I8, I2>,
Sequence<I8, I2, I1>,
Sequence<I8, I4>,
Sequence<I8, I4, I1>,
Sequence<I8, I4, I2>,
Sequence<I8, I4, I2, I1>,
Sequence<I16>>;
using VectorOffsetsLookupTable = Tuple<Sequence<>,
Sequence<I0>,
Sequence<I0>,
Sequence<I0, I2>,
Sequence<I0>,
Sequence<I0, I4>,
Sequence<I0, I4>,
Sequence<I0, I4, I6>,
Sequence<I0>,
Sequence<I0, I8>,
Sequence<I0, I8>,
Sequence<I0, I8, I10>,
Sequence<I0, I8>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12>,
Sequence<I0, I8, I12, I14>,
Sequence<I0>>;
static_for<0, tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::Size(), 1>{}(
[&](auto v_idx) {
constexpr auto VectorLoadSize =
tuple_element_t<SrcScalarPerVector, VectorSizeLookupTable>::At(v_idx);
constexpr auto LoadOffset =
tuple_element_t<SrcScalarPerVector, VectorOffsetsLookupTable>::At(v_idx);
using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
using src_vector_container_t = typename src_vector_container::type;
src_vector_container src_vector =
src_vector_container{src_buf.template Get<src_vector_container_t>(
src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if
// needed
src_element_op_(
op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
src_vector.template AsType<src_elem_op_vec_t>()[idx]);
});
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<dst_vector_t>(src_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
constexpr auto move_on_dim =
ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths);
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
return make_naive_tensor_descriptor_packed(src_access_lengths);
}
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto
GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector_>();
}
template <index_t ThreadScratchId>
@@ -327,14 +127,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
// calculate src data index
constexpr auto src_data_idx = ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
@@ -439,6 +239,194 @@ struct ThreadwiseTensorSliceTransfer_v3r1
#endif
}
// =====================================================================
// Public interface
// =====================================================================
public:
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const SrcElementwiseOperation& src_element_op,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const DstElementwiseOperation& dst_element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
static_assert(
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"wrong!");
static_assert(
is_same<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>::value,
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward and backward steps
const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access);
const auto src_backward_steps =
Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
// calculate src data index
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// maintain a container record is_src_valid, waiting for RunWrite use.
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
dst_vector_type op_r_v;
constexpr auto get_elem_op_vec_len = []() {
if constexpr(is_detected<is_pack8_invocable_t, decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack8_invocable)
return math::min(8, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack4_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack4_invocable)
return math::min(4, SrcScalarPerVector);
}
else if constexpr(is_detected<is_pack2_invocable_t,
decltype(src_element_op_)>::value)
{
if constexpr(decltype(src_element_op_)::is_pack2_invocable)
return math::min(2, SrcScalarPerVector);
}
else
{
return 1;
}
};
constexpr index_t elem_op_vec_len = get_elem_op_vec_len();
using src_elem_op_vec_t = typename vector_type<SrcData, elem_op_vec_len>::type;
using dst_elem_op_vec_t = typename vector_type<DstData, elem_op_vec_len>::type;
static_for<0,
tuple_element_t<SrcScalarPerVector, Helper::VectorSizeLookupTable>::Size(),
1>{}([&](auto v_idx) {
constexpr auto VectorLoadSize =
tuple_element_t<SrcScalarPerVector, Helper::VectorSizeLookupTable>::At(v_idx);
constexpr auto LoadOffset =
tuple_element_t<SrcScalarPerVector, Helper::VectorOffsetsLookupTable>::At(
v_idx);
using src_vector_container = vector_type_maker_t<SrcData, VectorLoadSize>;
using src_vector_container_t = typename src_vector_container::type;
src_vector_container src_vector =
src_vector_container{src_buf.template Get<src_vector_container_t>(
src_coord_.GetOffset() / PackedSize + LoadOffset, true)};
static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) {
// apply the src elementwise op and convert to DstData under the hood if
// needed
src_element_op_(op_r_v.template AsType<dst_elem_op_vec_t>()(idx + LoadOffset),
src_vector.template AsType<src_elem_op_vec_t>()[idx]);
});
});
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<dst_vector_t>(
src_data_idx_seq, op_r_v.template AsType<dst_vector_t>()[Helper::I0]);
constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths);
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if constexpr(move_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(
src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
}
else
{
move_tensor_coordinate(
src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
}
}
});
});
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
template <typename DstBuffer, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDesc& dst_desc,
DstBuffer& dst_buf,
@@ -470,21 +458,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward and backward steps
const auto dst_forward_steps = ComputeForwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_backward_steps = ComputeBackwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_backward_steps =
Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep =
ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths);
Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths);
// calculate dst data index
constexpr auto dst_data_idx = ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths,
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access);
constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths,
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access);
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
@@ -510,10 +499,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset() / PackedSize,
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[Helper::I0]);
constexpr auto move_on_dim =
ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths);
Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths);
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -543,21 +532,19 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ constexpr auto
GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
return ComputeCoordinateResetStep<SrcVectorDim, SrcScalarPerVector_, SrcDimAccessOrder>();
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
return ComputeCoordinateResetStep<DstVectorDim, DstScalarPerVector_, DstDimAccessOrder>();
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
Helper::MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep);
}
@@ -565,252 +552,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
Helper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
return ComputeThreadScratchDescriptor<SrcVectorDim, SrcScalarPerVector_>();
}
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
return make_naive_tensor_descriptor_packed(src_access_lengths);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
return ComputeThreadScratchDescriptor<DstVectorDim, DstScalarPerVector_>();
}
protected:
// Helper function to compute forward sweep pattern
// I.e. if we should move forward or backward in each of tensor's dimensions
template <typename OrderedAccessIdx, typename OrderedAccessLengths>
__device__ static constexpr auto
ComputeForwardSweep(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths)
{
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
static_for<1, i, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}
// Compute which dimensions should have their coordinates updated during iteration
// A dimension moves when it hasn't reached its end and all higher priority dimensions
// have completed their ranges
template <typename OrderedAccessIdx, typename OrderedAccessLengths>
__device__ static constexpr auto
ComputeMoveOnDim(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths)
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
// Compute data index from ordered access index, converting back to natural order
template <typename OrderedAccessIdx,
typename OrderedAccessLengths,
typename ForwardSweep,
typename DimAccessOrder,
typename ScalarPerAccess>
__device__ static constexpr auto
ComputeDataIndex(const OrderedAccessIdx& ordered_access_idx,
const OrderedAccessLengths& ordered_access_lengths,
const ForwardSweep& forward_sweep,
const DimAccessOrder& dim_access_order,
const ScalarPerAccess& scalar_per_access)
{
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_access_idx[i]
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) * scalar_per_access;
}
// Compute forward coordinate steps for each dimension
template <typename Desc, typename ScalarPerAccess>
__device__ static constexpr auto ComputeForwardSteps(const Desc& desc,
const ScalarPerAccess& scalar_per_access)
{
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, forward_step_idx);
},
Number<nDim>{});
}
// Compute backward coordinate steps for each dimension
template <typename Desc, typename ScalarPerAccess>
__device__ static constexpr auto ComputeBackwardSteps(const Desc& desc,
const ScalarPerAccess& scalar_per_access)
{
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(desc, backward_step_idx);
},
Number<nDim>{});
}
// Generic helper to compute thread scratch descriptor
template <index_t VectorDim, index_t ScalarPerVector_>
__device__ static constexpr auto ComputeThreadScratchDescriptor()
{
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector_>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
constexpr auto access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(access_lengths), Number<ScalarPerVector_>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == VectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(access_lengths_and_vector_length[i],
access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == VectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
// Generic helper to move slice window
template <typename Desc,
typename Coord,
bool ResetCoordinateAfterRun,
typename GetCoordinateResetStepFunc>
__device__ static void MoveSliceWindow(const Desc& desc,
Coord& coord,
const Index& slice_origin_step_idx,
GetCoordinateResetStepFunc get_reset_step)
{
// if coord was not reset by RunRead/RunWrite(), then need to adjust the step here
const auto adjusted_step_idx = ResetCoordinateAfterRun
? slice_origin_step_idx
: slice_origin_step_idx + get_reset_step();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx);
move_tensor_coordinate(desc, coord, adjusted_step);
}
// Generic helper to compute coordinate reset step
template <index_t VectorDim, index_t ScalarPerVector_, typename DimAccessOrder>
__device__ static constexpr auto ComputeCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector_>{}, Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
constexpr auto dim_access_order = DimAccessOrder{};
constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto ordered_access_lengths_minus_1 = generate_tuple(
[&](auto i) { return Number<ordered_access_lengths.At(i) - 1>{}; }, Number<nDim>{});
constexpr auto forward_sweep =
ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_access_lengths);
// calculate data index after last iteration, if it has not being reset
constexpr auto data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
scalar_per_access;
}();
//
constexpr auto reset_data_step = [&]() {
Index reset_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });
return reset_data_step_;
}();
return reset_data_step;
}
// =====================================================================
// Private data members
// =====================================================================
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto src_oob_thread_scratch_desc_ =

View File

@@ -7,43 +7,12 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
namespace detail {
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template <index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t DstVectorDim,
index_t DstScalarPerVector>
struct lambda_scalar_per_access_for_src_and_dst_idle
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if(i == SrcVectorDim && i == DstVectorDim)
{
return math::lcm(SrcScalarPerVector, DstScalarPerVector);
}
else if(i == SrcVectorDim)
{
return SrcScalarPerVector;
}
else if(i == DstVectorDim)
{
return DstScalarPerVector;
}
else
{
return 1;
}
}
};
} // namespace detail
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
@@ -84,12 +53,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using Helper = ThreadwiseTransferHelper_Serpentine;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant(
const SrcDesc& src_desc,
const Index& src_slice_origin,
@@ -139,7 +108,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
@@ -150,66 +118,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// make forward and backward steps
const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access);
const auto src_backward_steps =
Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
@@ -227,22 +150,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<src_vector_t>(
src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
src_data_idx_seq,
src_vector_container.template AsType<src_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths);
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -284,7 +196,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
"wrong! ScaleBuffer and ScaleData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
@@ -295,66 +206,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
constexpr auto ordered_scale_access_lengths =
container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order);
// make forward steps
const auto scale_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto scale_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(scale_desc, backward_step_idx);
},
Number<nDim>{});
// make forward and backward steps
const auto scale_forward_steps =
Helper::ComputeForwardSteps(scale_desc, scale_scalar_per_access);
const auto scale_backward_steps =
Helper::ComputeBackwardSteps(scale_desc, scale_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_scale_access_lengths)>{}([&](auto ordered_scale_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_scale_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_scale_access_idx, ordered_scale_access_lengths);
// calculate scale data index
constexpr auto scale_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_scale_access_idx[i]
: ordered_scale_access_lengths[i] - 1 -
ordered_scale_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) *
scale_scalar_per_access;
}();
constexpr auto scale_data_idx = Helper::ComputeDataIndex(ordered_scale_access_idx,
ordered_scale_access_lengths,
forward_sweep,
scale_dim_access_order,
scale_scalar_per_access);
constexpr auto scale_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<scale_data_idx[i]>{}; },
@@ -372,23 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
// copy data from scale_vector_container into scale_thread_scratch_
scale_thread_scratch_.template SetAsType<scale_vector_t>(
scale_data_idx_seq, scale_vector_container.template AsType<scale_vector_t>()[I0]);
scale_data_idx_seq,
scale_vector_container.template AsType<scale_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) =
ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_scale_access_idx, ordered_scale_access_lengths);
// move scale coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -409,17 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
}
});
});
// don't need to move scale coordinate back to slice origin
/*
if constexpr(SrcResetCoordinateAfterRun)
{
const auto scale_reset_step =
make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep());
move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step);
}
*/
}
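
The hunks above replace each per-class `forward_sweep`, data-index and `move_on_dim` lambda with a call into `ThreadwiseTransferHelper_Serpentine`. As a reading aid, here is a small host-side model of the serpentine bookkeeping those helpers compute; the function names (`compute_forward_sweep`, `compute_move_on_dim`) and plain `std::array` types are illustrative, not the CK device implementation:

```
// Host-side model of the serpentine bookkeeping (illustrative only, not the
// CK device code): compute_forward_sweep / compute_move_on_dim mirror what
// Helper::ComputeForwardSweep and Helper::ComputeMoveOnDim produce.
#include <array>
#include <cstdio>

constexpr int nDim = 3;

// forward_sweep[i]: does dimension i currently advance forward or backward?
// Dim 0 always goes forward; every other dim flips whenever the linearized
// count of its slower dims is odd.
std::array<bool, nDim> compute_forward_sweep(const std::array<int, nDim>& idx,
                                             const std::array<int, nDim>& len)
{
    std::array<bool, nDim> sweep{};
    sweep[0] = true;
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = idx[0]; // linearize the slower dims 0..i-1
        for(int j = 1; j < i; ++j)
            tmp = tmp * len[j] + idx[j];
        sweep[i] = (tmp % 2 == 0);
    }
    return sweep;
}

// move_on_dim[i]: advance dim i only when it is not at its end and every
// faster dim is at its end (odometer-with-reversal rule).
std::array<bool, nDim> compute_move_on_dim(const std::array<int, nDim>& idx,
                                           const std::array<int, nDim>& len)
{
    std::array<bool, nDim> move{};
    for(int i = 0; i < nDim; ++i)
    {
        move[i] = idx[i] < len[i] - 1;
        for(int j = i + 1; j < nDim; ++j)
            move[i] = move[i] && (idx[j] == len[j] - 1);
    }
    return move;
}

int main()
{
    const std::array<int, nDim> len{2, 2, 4}; // ordered access lengths
    const std::array<int, nDim> idx{1, 0, 3}; // one position in the access space
    const auto sweep = compute_forward_sweep(idx, len);
    const auto move  = compute_move_on_dim(idx, len);
    for(int i = 0; i < nDim; ++i)
        std::printf("dim %d: forward=%d move=%d\n", i, sweep[i], move[i]);
}
```

None of this depends on the data type, descriptor, or vector width, which is why the helper versions can be shared across the src, scale, and dst loops of this class as well as across the other variants.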
template <index_t ThreadScratchId>
@@ -460,10 +304,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
@@ -504,10 +348,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
// Do fast numeric convert
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access_for_src_and_dst_idle<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
SrcScalarPerVector,
DstVectorDim,
DstScalarPerVector>{},
Number<nDim>{});
constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
@@ -528,15 +372,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
src_converted_thread_scratch_.template SetAsType<src_converted_vector_t>(
access_idx,
src_converted_vector_container.template AsType<src_converted_vector_t>()[I0]);
src_converted_vector_container
.template AsType<src_converted_vector_t>()[Helper::I0]);
});
// Element-scale operation, expect packed multiplication
static_ford<SliceLengths>{}([&](auto idx) {
DstData dst_v;
constexpr auto scale_idx = Sequence<I0, idx.At(1), I0>{};
// printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(),
// *(reinterpret_cast<const uint16_t*>(&scale_thread_scratch_[scale_idx])));
constexpr auto scale_idx = Sequence<Helper::I0, idx.At(1), Helper::I0>{};
src_element_op_(dst_v,
src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]);
dst_thread_scratch_(idx) = dst_v;
@@ -562,7 +405,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
"wrong! SrcBuffer or DstBuffer data type is wrong");
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
@@ -573,66 +415,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// make forward and backward steps
const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_backward_steps =
Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths);
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths,
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access);
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
@@ -660,22 +457,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths);
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -707,293 +492,52 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
return Helper::ComputeCoordinateResetStep<SliceLengths,
SrcVectorDim,
SrcScalarPerVector,
SrcDimAccessOrder>();
}
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
return Helper::ComputeCoordinateResetStep<SliceLengths,
DstVectorDim,
DstScalarPerVector,
DstDimAccessOrder>();
}
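
Both `Get*CoordinateResetStep` bodies above collapse into `Helper::ComputeCoordinateResetStep<...>()`. A host-side model of the value it produces follows; the code is illustrative and assumes an identity dim access order, so the old2new reorder in the real implementation is dropped:

```
// Host-side model of ComputeCoordinateResetStep (illustrative; assumes an
// identity dim access order so the old2new reorder can be dropped). After the
// final serpentine access each dim rests either at its last position (if its
// final sweep was forward) or at 0; the reset step is the negation of that.
#include <array>
#include <cstdio>

constexpr int nDim = 3;

std::array<int, nDim> compute_reset_step(const std::array<int, nDim>& ordered_len,
                                         const std::array<int, nDim>& scalar_per_access)
{
    // forward_sweep evaluated at the last access index (len - 1 in every dim)
    std::array<bool, nDim> sweep{};
    sweep[0] = true;
    for(int i = 1; i < nDim; ++i)
    {
        int tmp = ordered_len[0] - 1;
        for(int j = 1; j < i; ++j)
            tmp = tmp * ordered_len[j] + ordered_len[j] - 1;
        sweep[i] = (tmp % 2 == 0);
    }

    std::array<int, nDim> step{};
    for(int i = 0; i < nDim; ++i)
    {
        const int last = sweep[i] ? ordered_len[i] - 1 : 0; // final resting index on dim i
        step[i]        = -last * scalar_per_access[i];      // step back to the slice origin
    }
    return step;
}

int main()
{
    // 2x2x4 accesses, 4-wide vectors on the fastest dim -> the last access
    // rests at (1, 0, 0), so the reset step is (-1, 0, 0).
    const auto step = compute_reset_step({2, 2, 4}, {1, 1, 4});
    std::printf("reset step = (%d, %d, %d)\n", step[0], step[1], step[2]);
}
```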
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
Helper::MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
Helper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector>();
}
__device__ static constexpr auto GetScaleThreadScratchDescriptor()
{
constexpr auto scale_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, ScaleScalarPerVector>{}, Number<nDim>{});
constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access;
constexpr auto scale_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(scale_access_lengths), Number<ScaleScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(scale_access_lengths_and_vector_length[i],
scale_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(scale_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, ScaleScalarPerVector>();
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector>();
}
private:
@@ -1002,11 +546,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
decltype(GetScaleThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
/*
template <bool kLastDim>
struct ScaleThreadScratchDesc{};
*/
// Registers, contain raw data loaded from global buffer
using SrcThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
SrcData,
@@ -1050,6 +589,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
ScaleCoord scale_coord_;
DstCoord dst_coord_;
const SrcElementwiseOperation src_element_op_;
// Note: scale_element_op_ appears unused but is retained for API completeness
const ScaleElementwiseOperation scale_element_op_;
const DstElementwiseOperation dst_element_op_;
};
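
All three `Get*ThreadScratchDescriptor` functions above now come from `Helper::ComputeThreadScratchDescriptor<SliceLengths, VectorDim, ScalarPerVector>()`. The descriptor it builds is the packed `[access_lengths..., ScalarPerVector]` layout with the vector dimension re-merged by division/mod. A standalone model of the resulting addressing (illustrative only, not the CK descriptor machinery):

```
// Host-side model (illustrative) of what the shared
// ComputeThreadScratchDescriptor builds: a packed
// [access_lengths..., ScalarPerVector] buffer viewed as the original nDim
// slice, with the vector dim split by division/mod.
#include <array>
#include <cstdio>

constexpr int nDim            = 2;
constexpr int VectorDim       = 1; // illustrative: vectorize the fastest dim
constexpr int ScalarPerVector = 4;

// Offset of element `idx` (in original slice coordinates) inside the packed scratch.
int scratch_offset(const std::array<int, nDim>& slice_lengths, const std::array<int, nDim>& idx)
{
    std::array<int, nDim + 1> packed_len{}; // [access_lengths..., ScalarPerVector]
    for(int i = 0; i < nDim; ++i)
        packed_len[i] = (i == VectorDim) ? slice_lengths[i] / ScalarPerVector : slice_lengths[i];
    packed_len[nDim] = ScalarPerVector;

    std::array<int, nDim + 1> packed_idx{};
    for(int i = 0; i < nDim; ++i)
        packed_idx[i] = (i == VectorDim) ? idx[i] / ScalarPerVector : idx[i];
    packed_idx[nDim] = idx[VectorDim] % ScalarPerVector; // position inside the vector

    int offset = 0; // row-major packed layout
    for(int i = 0; i <= nDim; ++i)
        offset = offset * packed_len[i] + packed_idx[i];
    return offset;
}

int main()
{
    // A 4x8 slice with 4-wide vectors on dim 1: element (2, 5) lands in
    // vector (2, 1), lane 1 -> offset 21 in the packed [4, 2, 4] buffer.
    std::printf("offset = %d\n", scratch_offset({4, 8}, {2, 5}));
}
```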


@@ -49,27 +49,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using Helper = ThreadwiseTransferHelper_Serpentine;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto I8 = Number<8>{};
static constexpr auto I10 = Number<10>{};
static constexpr auto I12 = Number<12>{};
static constexpr auto I13 = Number<13>{};
static constexpr auto I14 = Number<14>{};
static constexpr auto I16 = Number<16>{};
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
@@ -142,7 +126,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
"wrong! SrcBuffer and SrcData data type are inconsistent");
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
@@ -156,66 +139,23 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// make forward steps
const auto src_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto src_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(src_desc, backward_step_idx);
},
Number<nDim>{});
// make forward and backward steps
const auto src_forward_steps = Helper::ComputeForwardSteps(src_desc, src_scalar_per_access);
const auto src_backward_steps =
Helper::ComputeBackwardSteps(src_desc, src_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
@@ -274,24 +214,20 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// copy data from src_vector_container into src_thread_scratch_
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<dst_vector_t>(src_data_idx_seq,
op_r_v.template AsType<dst_vector_t>()[I0]);
.template SetAsType<dst_vector_t>(
src_data_idx_seq, op_r_v.template AsType<dst_vector_t>()[Helper::I0]);
// Gather-specific: skip gather dimension during coordinate movement
auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
auto move_on_dim_ =
Helper::ComputeMoveOnDim(ordered_src_access_idx, ordered_src_access_lengths);
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
move_on_dim_(i) &= i.value != ordered_gather_dim;
});
static_for<0, nDim, 1>{}(
[&](auto i) { move_on_dim_(i) &= i.value != ordered_gather_dim; });
return move_on_dim_;
}();
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
if(move_on_dim[i])
@@ -351,38 +287,14 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_src_access_idx, ordered_src_access_lengths);
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx = Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths,
forward_sweep,
src_dim_access_order,
src_scalar_per_access);
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
@@ -501,7 +413,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
"wrong! SrcBuffer or DstBuffer data type is wrong");
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
@@ -512,66 +423,21 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// make forward steps
const auto dst_forward_steps = generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, forward_step_idx);
},
Number<nDim>{});
// make backward steps
const auto dst_backward_steps = generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
});
return make_tensor_coordinate_step(dst_desc, backward_step_idx);
},
Number<nDim>{});
// make forward and backward steps
const auto dst_forward_steps = Helper::ComputeForwardSteps(dst_desc, dst_scalar_per_access);
const auto dst_backward_steps =
Helper::ComputeBackwardSteps(dst_desc, dst_scalar_per_access);
// loop over tensor and copy
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_dst_access_idx, ordered_dst_access_lengths);
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
: ordered_dst_access_lengths[i] - 1 -
ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
constexpr auto dst_data_idx = Helper::ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths,
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access);
constexpr auto dst_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<dst_data_idx[i]>{}; }, Number<dst_data_idx.Size()>{});
@@ -599,22 +465,10 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_buf.template Set<dst_vector_t>(
dst_coord_.GetOffset() / PackedSize,
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim =
Helper::ComputeMoveOnDim(ordered_dst_access_idx, ordered_dst_access_lengths);
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -644,10 +498,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
}
}
// Gather-specific: src coordinate reset zeroes the gather dimension
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
@@ -658,29 +511,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto ordered_access_lengths_minus_1 = generate_tuple(
[&](auto i) { return Number<ordered_src_access_lengths.At(i) - 1>{}; }, Number<nDim>{});
constexpr auto forward_sweep =
Helper::ComputeForwardSweep(ordered_access_lengths_minus_1, ordered_src_access_lengths);
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
MultiIndex<nDim> ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
@@ -690,9 +527,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
src_scalar_per_access;
}();
//
// Gather-specific: don't reset the gather dimension
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
MultiIndex<nDim> reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) {
reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i];
@@ -705,137 +542,32 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
return Helper::ComputeCoordinateResetStep<SliceLengths,
DstVectorDim,
DstScalarPerVector_,
DstDimAccessOrder>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
Helper::MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
src_desc, src_coord_, src_slice_origin_step_idx, GetSrcCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by RunWrite(), then need to adjust the step here
const auto adjusted_step_idx =
DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
Helper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetDstCoordinateResetStep);
}
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector_>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector_>();
}
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
@@ -850,50 +582,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector_>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector_>();
}
private:
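
The `Move*SliceWindow` bodies in the two files above reduce to a single `Helper::MoveSliceWindow<Desc, Coord, ResetCoordinateAfterRun>(desc, coord, step_idx, GetResetStep)` call; the gather variant keeps its own `GetSrcCoordinateResetStep` because the gather dimension must stay in place. A minimal host-side sketch of that pattern (illustrative names and plain arrays; the real helper operates on tensor coordinates and coordinate steps):

```
// Host-side sketch of the shared MoveSliceWindow pattern (illustrative names;
// the real helper operates on tensor coordinates/steps, not plain arrays).
#include <array>
#include <cstdio>

constexpr int nDim = 2;
using Index = std::array<int, nDim>;

template <bool ResetCoordinateAfterRun, typename GetResetStep>
void move_slice_window(Index& coord, const Index& step, GetResetStep get_reset_step)
{
    Index adjusted = step;
    if(!ResetCoordinateAfterRun)          // the run left the coord at its last position,
        for(int i = 0; i < nDim; ++i)     // so fold the reset step into this move
            adjusted[i] += get_reset_step()[i];
    for(int i = 0; i < nDim; ++i)         // stand-in for make_tensor_coordinate_step +
        coord[i] += adjusted[i];          // move_tensor_coordinate
}

int main()
{
    Index coord{0, 12}; // drifted by a previous run that did not reset
    move_slice_window<false>(coord, Index{0, 16}, [] { return Index{0, -12}; });
    std::printf("coord = (%d, %d)\n", coord[0], coord[1]); // (0, 16)
}
```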


@@ -7,10 +7,11 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/is_detected.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Assume:
@@ -48,6 +49,8 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
using Helper = ThreadwiseTransferHelper_Serpentine;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
@@ -61,8 +64,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2
using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray<Index, nSrc>{}));
using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray<Index, nDst>{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v3r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
@@ -101,7 +102,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access_tuple = generate_tuple(
[&](auto src_i) {
return generate_sequence(
@@ -129,40 +129,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nSrc>{});
// make forward steps
// make forward and backward steps
const auto src_forward_steps_tuple = generate_tuple(
[&](auto src_i) {
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) =
(i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0;
});
return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx);
},
Number<nDim>{});
return Helper::ComputeForwardSteps(src_descs.At(src_i),
src_scalar_per_access_tuple.At(src_i));
},
Number<nSrc>{});
// make backward steps
const auto src_backward_steps_tuple = generate_tuple(
[&](auto src_i) {
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value)
? -src_scalar_per_access_tuple.At(src_i)[i]
: 0;
});
return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx);
},
Number<nDim>{});
return Helper::ComputeBackwardSteps(src_descs.At(src_i),
src_scalar_per_access_tuple.At(src_i));
},
Number<nSrc>{});
@@ -171,39 +149,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_ford<remove_cvref_t<decltype(ordered_src_access_lengths_tuple.At(src_i))>>{}(
[&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths_tuple[j] +
ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep = Helper::ComputeForwardSweep(
ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i));
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_src_access_idx[i]
: ordered_src_access_lengths_tuple.At(src_i)[i] -
1 - ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access_tuple.At(src_i);
}();
constexpr auto src_data_idx =
Helper::ComputeDataIndex(ordered_src_access_idx,
ordered_src_access_lengths_tuple.At(src_i),
forward_sweep,
src_dim_access_order,
src_scalar_per_access_tuple.At(src_i));
constexpr auto src_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<src_data_idx[i]>{}; },
@@ -227,24 +182,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2
.At(src_i)
.template SetAsType<src_vector_t>(
src_data_idx_seq,
src_vector_container.template AsType<src_vector_t>()[I0]);
src_vector_container.template AsType<src_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_src_access_idx[i] <
ordered_src_access_lengths_tuple.At(src_i)[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_src_access_idx[j] ==
ordered_src_access_lengths_tuple.At(src_i)[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim = Helper::ComputeMoveOnDim(
ordered_src_access_idx, ordered_src_access_lengths_tuple.At(src_i));
// move src coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -287,18 +228,30 @@ struct ThreadwiseTensorSliceTransfer_v3r2
{
// TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE
// (it requires to add Elementwise support in transpose_vectors)
static_ford<SliceLengths>{}([&](auto idx) {
const auto src_data_refs = generate_tie(
[&](auto src_i) -> const auto& {
return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
},
Number<nSrc>{});
if constexpr(nSrc == 1 && nDst == 1)
{
// Fast path: direct element transfer, no generate_tie/unpack2 overhead
static_ford<SliceLengths>{}([&](auto idx) {
element_op_(dst_thread_scratch_tuple_.At(Number<0>{})(idx),
src_thread_scratch_tuple_[thread_scratch_id].At(Number<0>{})[idx]);
});
}
else
{
// General path: use generate_tie + unpack2 for multi-src/dst
static_ford<SliceLengths>{}([&](auto idx) {
const auto src_data_refs = generate_tie(
[&](auto src_i) -> const auto& {
return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx];
},
Number<nSrc>{});
auto dst_data_refs = generate_tie(
[&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); },
Number<nDst>{});
unpack2(element_op_, dst_data_refs, src_data_refs);
});
auto dst_data_refs = generate_tie(
[&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); },
Number<nDst>{});
unpack2(element_op_, dst_data_refs, src_data_refs);
});
}
}
template <typename DstBuffers, index_t ThreadScratchId = 0>
@@ -311,7 +264,6 @@ struct ThreadwiseTensorSliceTransfer_v3r2
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
// src scalar per access on each dim
// TODO: don't use this
constexpr auto dst_scalar_per_access_tuple = generate_tuple(
[&](auto dst_i) {
return generate_sequence(
@@ -334,40 +286,18 @@ struct ThreadwiseTensorSliceTransfer_v3r2
},
Number<nDst>{});
// make forward steps
// make forward and backward steps
const auto dst_forward_steps_tuple = generate_tuple(
[&](auto dst_i) {
return generate_tuple(
[&](auto i) {
Index forward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
forward_step_idx(j) =
(i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0;
});
return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx);
},
Number<nDim>{});
return Helper::ComputeForwardSteps(dst_descs.At(dst_i),
dst_scalar_per_access_tuple.At(dst_i));
},
Number<nDst>{});
// make backward steps
const auto dst_backward_steps_tuple = generate_tuple(
[&](auto dst_i) {
return generate_tuple(
[&](auto i) {
Index backward_step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
backward_step_idx(j) = (i.value == j.value)
? -dst_scalar_per_access_tuple.At(dst_i)[i]
: 0;
});
return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx);
},
Number<nDim>{});
return Helper::ComputeBackwardSteps(dst_descs.At(dst_i),
dst_scalar_per_access_tuple.At(dst_i));
},
Number<nDst>{});
@@ -376,39 +306,16 @@ struct ThreadwiseTensorSliceTransfer_v3r2
static_ford<remove_cvref_t<decltype(ordered_dst_access_lengths_tuple.At(dst_i))>>{}(
[&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] +
ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
constexpr auto forward_sweep = Helper::ComputeForwardSweep(
ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i));
// calculate dst data index
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i]
? ordered_dst_access_idx[i]
: ordered_dst_access_lengths_tuple.At(dst_i)[i] -
1 - ordered_dst_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access_tuple.At(dst_i);
}();
constexpr auto dst_data_idx =
Helper::ComputeDataIndex(ordered_dst_access_idx,
ordered_dst_access_lengths_tuple.At(dst_i),
forward_sweep,
dst_dim_access_order,
dst_scalar_per_access_tuple.At(dst_i));
constexpr auto dst_data_idx_seq =
generate_sequence_v2([&](auto i) { return Number<dst_data_idx[i]>{}; },
@@ -434,24 +341,10 @@ struct ThreadwiseTensorSliceTransfer_v3r2
dst_bufs.At(dst_i).template Update<DstInMemOp, dst_vector_t>(
dst_coords_.At(dst_i).GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[Helper::I0]);
constexpr auto move_on_dim = [&]() constexpr {
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] <
ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &=
ordered_dst_access_idx[j] ==
ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1;
});
});
return move_on_dim_;
}();
constexpr auto move_on_dim = Helper::ComputeMoveOnDim(
ordered_dst_access_idx, ordered_dst_access_lengths_tuple.At(dst_i));
// move dst coord
static_for<0, nDim, 1>{}([&](auto i) {
@@ -491,121 +384,19 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template <index_t src_i>
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index after last iteration in RunRead(), if it has not being reset by
// RunRead()
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
return reset_src_data_step_;
}();
return reset_src_data_step;
return Helper::ComputeCoordinateResetStep<SliceLengths,
SrcVectorDim,
SrcsScalarPerVector::At(src_i),
SrcDimAccessOrder>();
}
template <index_t dst_i>
__device__ static constexpr auto GetDstCoordinateResetStep()
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_dim_access_order = DstDimAccessOrder{};
constexpr auto ordered_dst_access_lengths =
container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_lengths[I0] - 1;
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
// RunWrite()
constexpr auto dst_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
});
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access.At(dst_i);
}();
//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step_;
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
return reset_dst_data_step_;
}();
return reset_dst_data_step;
return Helper::ComputeCoordinateResetStep<SliceLengths,
DstVectorDim,
DstsScalarPerVector::At(dst_i),
DstDimAccessOrder>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
@@ -649,103 +440,17 @@ struct ThreadwiseTensorSliceTransfer_v3r2
template <index_t src_i>
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcsScalarPerVector::At(src_i)>{},
Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length =
container_push_back(sequence_to_tuple_of_number(src_access_lengths),
Number<SrcsScalarPerVector::At(src_i)>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::ComputeThreadScratchDescriptor<SliceLengths,
SrcVectorDim,
SrcsScalarPerVector::At(src_i)>();
}
template <index_t dst_i>
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstsScalarPerVector::At(dst_i)>{},
Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length =
container_push_back(sequence_to_tuple_of_number(dst_access_lengths),
Number<DstsScalarPerVector::At(dst_i)>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return Helper::ComputeThreadScratchDescriptor<SliceLengths,
DstVectorDim,
DstsScalarPerVector::At(dst_i)>();
}
__device__ static constexpr auto MakeSrcThreadScratchTuple()
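
Besides delegating to the shared helpers, the v3r2 hunks above add a fast path in `TransferDataFromSrcThreadScratchToDstThreadScratch`: when `nSrc == 1 && nDst == 1` the element op is applied directly instead of going through `generate_tie`/`unpack2`. A standalone sketch of that dispatch (hypothetical names, std types standing in for the CK scratch containers):

```
// Standalone sketch of the fast-path dispatch (hypothetical names; std types
// in place of the CK scratch containers and element op).
#include <array>
#include <cstddef>
#include <cstdio>
#include <tuple>

template <std::size_t nSrc, std::size_t nDst, typename Srcs, typename Dsts, typename Op>
void transfer_scratch(Srcs srcs, Dsts dsts, Op op, std::size_t n)
{
    for(std::size_t e = 0; e < n; ++e)
    {
        if constexpr(nSrc == 1 && nDst == 1)
        {
            // fast path: apply the op directly, no tuple-of-references plumbing
            op(std::get<0>(dsts)[e], std::get<0>(srcs)[e]);
        }
        else
        {
            // general path: collect references to every dst/src element, apply op once
            std::apply(
                [&](auto&... d) {
                    std::apply([&](const auto&... s) { op(d[e]..., s[e]...); }, srcs);
                },
                dsts);
        }
    }
}

int main()
{
    std::array<float, 4> src{1, 2, 3, 4};
    std::array<float, 4> dst{};
    transfer_scratch<1, 1>(std::tie(src), std::tie(dst),
                           [](float& d, const float& s) { d = 2.0f * s; }, src.size());
    std::printf("dst = {%g, %g, %g, %g}\n", dst[0], dst[1], dst[2], dst[3]);
}
```

The single-src/single-dst case covers the common instantiations, so the generic tuple plumbing is only compiled where it is actually needed.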


@@ -42,8 +42,6 @@ struct ThreadwiseTensorSliceTransfer_v4r1
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx)
: src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
{


@@ -44,9 +44,6 @@ struct ThreadwiseTensorSliceTransfer_v5r1
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
__device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,


@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -40,11 +42,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
@@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[SFCHelper::I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
@@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
return SFCHelper::ComputeSFCCoordinateResetStep<SliceLengths,
DimAccessOrder,
decltype(scalar_per_access)>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep);
}
private:
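GetCoordinateResetStep and both MoveSliceWindow methods above now forward to the shared SFC helper; the same delegation repeats in v6r1r2, v6r2, and v6r3 below. The following is a minimal sketch of the two shared routines, reconstructed from the removed inline bodies and the call sites visible in this hunk; the template-parameter order matches the calls, but the local names are illustrative.

```
// Fold the coordinate-reset step into the window movement when the coordinate
// was not already reset at the end of Run()/RunRead().
template <typename Desc,
          typename Coord,
          bool ResetCoordinateAfterRun,
          typename StepIdx,
          typename GetResetStep>
__device__ static void MoveSliceWindow(const Desc& desc,
                                       Coord& coord,
                                       const StepIdx& slice_origin_step_idx,
                                       GetResetStep get_coordinate_reset_step)
{
    const auto adjusted_step_idx = ResetCoordinateAfterRun
                                       ? slice_origin_step_idx
                                       : slice_origin_step_idx + get_coordinate_reset_step();
    const auto adjusted_step = make_tensor_coordinate_step(desc, adjusted_step_idx);
    move_tensor_coordinate(desc, coord, adjusted_step);
}

// Step from the last access of the space-filling curve back to the first.
template <typename SliceLengths, typename DimAccessOrder, typename ScalarPerAccess>
__device__ static constexpr auto ComputeSFCCoordinateResetStep()
{
    using SFC = SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<ScalarPerAccess>>;
    constexpr auto num_access = SFC::GetNumOfAccess();

    if constexpr(num_access == 0)
    {
        return typename SFC::Index{};
    }
    else
    {
        return SFC::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
    }
}
```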

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -39,11 +41,11 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
@@ -120,7 +122,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[SFCHelper::I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
@@ -156,52 +158,25 @@ struct ThreadwiseTensorSliceTransfer_v6r1r2
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
return SFCHelper::ComputeSFCCoordinateResetStep<SliceLengths,
DimAccessOrder,
decltype(scalar_per_access)>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<SrcDesc, SrcCoord, SrcResetCoordinateAfterRun>(
src_desc, src_coord_, src_slice_origin_step_idx, GetCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep);
}
private:

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -43,12 +45,12 @@ struct ThreadwiseTensorSliceTransfer_v6r2
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{}));
using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc,
const Index& src0_slice_origin,
const Src1Desc& src1_desc,
@@ -141,7 +143,7 @@ struct ThreadwiseTensorSliceTransfer_v6r2
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[SFCHelper::I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
@@ -187,67 +189,30 @@ struct ThreadwiseTensorSliceTransfer_v6r2
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
return SFCHelper::ComputeSFCCoordinateResetStep<SliceLengths,
DimAccessOrder,
decltype(scalar_per_access)>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc,
const Index& src0_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = Src0ResetCoordinateAfterRun
? src0_slice_origin_step_idx
: src0_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx);
move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<Src0Desc, Src0Coord, Src0ResetCoordinateAfterRun>(
src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc,
const Index& src1_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = Src1ResetCoordinateAfterRun
? src1_slice_origin_step_idx
: src1_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx);
move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<Src1Desc, Src1Coord, Src1ResetCoordinateAfterRun>(
src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep);
}
private:

View File

@@ -8,6 +8,8 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
namespace ck {
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
@@ -46,13 +48,13 @@ struct ThreadwiseTensorSliceTransfer_v6r3
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{}));
using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{}));
using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
const Index& src0_slice_origin,
const Src1Desc& src1_desc,
@@ -165,7 +167,7 @@ struct ThreadwiseTensorSliceTransfer_v6r3
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
dst_vector_container.template AsType<dst_vector_t>()[SFCHelper::I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
@@ -221,82 +223,37 @@ struct ThreadwiseTensorSliceTransfer_v6r3
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
return SFCHelper::ComputeSFCCoordinateResetStep<SliceLengths,
DimAccessOrder,
decltype(scalar_per_access)>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc,
const Index& src0_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = Src0ResetCoordinateAfterRun
? src0_slice_origin_step_idx
: src0_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx);
move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<Src0Desc, Src0Coord, Src0ResetCoordinateAfterRun>(
src0_desc, src0_coord_, src0_slice_origin_step_idx, GetCoordinateResetStep);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc,
const Index& src1_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = Src1ResetCoordinateAfterRun
? src1_slice_origin_step_idx
: src1_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx);
move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<Src1Desc, Src1Coord, Src1ResetCoordinateAfterRun>(
src1_desc, src1_coord_, src1_slice_origin_step_idx, GetCoordinateResetStep);
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc,
const Index& src2_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = Src2ResetCoordinateAfterRun
? src2_slice_origin_step_idx
: src2_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx);
move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<Src2Desc, Src2Coord, Src2ResetCoordinateAfterRun>(
src2_desc, src2_coord_, src2_slice_origin_step_idx, GetCoordinateResetStep);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
SFCHelper::MoveSliceWindow<DstDesc, DstCoord, DstResetCoordinateAfterRun>(
dst_desc, dst_coord_, dst_slice_origin_step_idx, GetCoordinateResetStep);
}
private:

View File

@@ -55,6 +55,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
@@ -124,17 +126,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto generate_vectors()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
return SFCHelper::MakeVectorContainerTuple<DataTypes, ScalarPerVector>();
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
@@ -473,98 +465,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector>();
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
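generate_vectors in v7r2 (and in the v7r3 variants that follow) now delegates to the shared MakeVectorContainerTuple. A minimal sketch of that helper, reconstructed from the removed body of generate_vectors; the signature shape is inferred from the call site and is otherwise an assumption.

```
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto MakeVectorContainerTuple()
{
    // One vector_type container per entry in the DataTypes tuple, each sized
    // ScalarPerVector; independent of every other transfer template parameter.
    auto data_types = DataTypes{};
    constexpr index_t num = data_types.Size();

    return generate_tuple(
        [&](auto i) {
            using DataType = remove_cvref_t<decltype(data_types[i])>;
            return vector_type_maker_t<DataType, ScalarPerVector>{};
        },
        Number<num>{});
}
```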

View File

@@ -56,6 +56,8 @@ struct ThreadwiseTensorSliceTransfer_v7r3
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
// return a tuple of coordinates for a tuple of tensors
template <typename Descs,
typename Indices,
@@ -129,17 +131,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto generate_vectors()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
return SFCHelper::MakeVectorContainerTuple<DataTypes, ScalarPerVector>();
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
@@ -615,100 +607,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
// Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector>();
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
// Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason

View File

@@ -63,6 +63,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static constexpr index_t nDst = DstDescs::Size();
using Index = MultiIndex<nDim>;
using SFCHelper = ThreadwiseTransferHelper_SFC;
static constexpr index_t scatter_num = SliceLengths{}.At(Number<ScatterDim>{});
// return a tuple of coordinates for a tuple of tensors
@@ -134,17 +135,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
template <typename DataTypes, index_t ScalarPerVector>
__device__ static auto generate_vectors()
{
auto data_types = DataTypes{};
constexpr index_t num = data_types.Size();
return generate_tuple(
[&](auto i) {
using DataType = remove_cvref_t<decltype(data_types[i])>;
return vector_type_maker_t<DataType, ScalarPerVector>{};
},
Number<num>{});
return SFCHelper::MakeVectorContainerTuple<DataTypes, ScalarPerVector>();
}
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
@@ -506,100 +497,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
__device__ static constexpr auto GetSrcThreadScratchDescriptor()
{
// constexpr auto src_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{},
// Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(src_access_lengths), Number<SrcScalarPerVector>{});
// 1st stage of transforms
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(src_access_lengths_and_vector_length[i],
src_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(src_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == SrcVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, SrcVectorDim, SrcScalarPerVector>();
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
// constexpr auto dst_scalar_per_access = generate_sequence(
// detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{},
// Number<nDim>{});
constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
constexpr auto dst_access_lengths_and_vector_length = container_push_back(
sequence_to_tuple_of_number(dst_access_lengths), Number<DstScalarPerVector>{});
constexpr auto desc0 =
make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length);
// 2nd stage of transforms
constexpr auto transforms = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return make_merge_transform_v3_division_mod(
make_tuple(dst_access_lengths_and_vector_length[i],
dst_access_lengths_and_vector_length[Number<nDim>{}]));
}
else
{
return make_pass_through_transform(dst_access_lengths_and_vector_length[i]);
}
},
Number<nDim>{});
constexpr auto low_dim_idss = generate_tuple(
[&](auto i) {
if constexpr(i == DstVectorDim)
{
return Sequence<i.value, nDim>{};
}
else
{
return Sequence<i.value>{};
}
},
Number<nDim>{});
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
return SFCHelper::
ComputeThreadScratchDescriptor<SliceLengths, DstVectorDim, DstScalarPerVector>();
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason

View File

@@ -255,6 +255,7 @@ add_compile_options(-Wno-c++20-extensions)
add_subdirectory(ck_tile)
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(threadwise_transfer_helper)
add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)

View File

@@ -0,0 +1,4 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_gtest_executable(test_threadwise_transfer_helper test_threadwise_transfer_helper.cpp)

View File

@@ -0,0 +1,748 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp"
using namespace ck;
// =============================================================================
// ThreadwiseTransferHelper_Base tests
// =============================================================================
TEST(ThreadwiseTransferHelperBase, CompileTimeConstants)
{
EXPECT_EQ(ThreadwiseTransferHelper_Base::I0.value, 0);
EXPECT_EQ(ThreadwiseTransferHelper_Base::I1.value, 1);
EXPECT_EQ(ThreadwiseTransferHelper_Base::I2.value, 2);
EXPECT_EQ(ThreadwiseTransferHelper_Base::I4.value, 4);
EXPECT_EQ(ThreadwiseTransferHelper_Base::I8.value, 8);
EXPECT_EQ(ThreadwiseTransferHelper_Base::I16.value, 16);
}
TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySerpentine)
{
// Serpentine inherits all constants from Base via public inheritance.
EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I0.value, 0);
EXPECT_EQ(ThreadwiseTransferHelper_Serpentine::I16.value, 16);
}
TEST(ThreadwiseTransferHelperBase, ConstantsInheritedBySFC)
{
// SFC inherits all constants from Base via public inheritance.
EXPECT_EQ(ThreadwiseTransferHelper_SFC::I0.value, 0);
EXPECT_EQ(ThreadwiseTransferHelper_SFC::I16.value, 16);
}
// =============================================================================
// ThreadwiseTransferHelper_Base::MoveSliceWindow tests
// =============================================================================
TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetAlreadyDone)
{
/*
* Scenario: v3r1's MoveSrcSliceWindow after RunRead has already reset
* the coordinate back to the slice origin (SrcResetCoordinateAfterRun=true).
*
* 2D packed tensor (4 rows x 8 columns), modelling a tile transfer:
*
* col: 0 1 2 3 4 5 6 7
* row 0: [*] . . . . . . . <-- start at (0,0), offset=0
* row 1: . . . . . . . .
* row 2: . . . . . . . .
* row 3: . . . . . . . .
*
* Step = (1, 0): move one row down.
* Reset step = (-3, 0): would move 3 rows up (irrelevant here).
*
* Since ResetCoordinateAfterRun=true, the reset step is NOT fused
* into the movement. The coordinate simply moves by the step alone.
*
* Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8
*/
using Helper = ThreadwiseTransferHelper_Base;
constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));
EXPECT_EQ(coord.GetOffset(), 0);
const auto step_idx = make_multi_index(1, 0);
auto get_reset_step = []() { return make_multi_index(-3, 0); };
Helper::MoveSliceWindow<decltype(desc), decltype(coord), true>(
desc, coord, step_idx, get_reset_step);
// Coordinate moved by step only: (0,0) -> (1,0)
// Offset in row-major packed layout: 1*8 + 0 = 8
EXPECT_EQ(coord.GetOffset(), 8);
}
TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_ResetFused)
{
/*
* Scenario: v3r1's MoveSrcSliceWindow when RunRead did NOT reset
* the coordinate (SrcResetCoordinateAfterRun=false). This is the
* optimization path where MoveSliceWindow fuses the reset step
* with the movement step to save a separate coordinate adjustment.
*
* Same 2D packed tensor (4 rows x 8 columns):
*
* col: 0 1 2 3 4 5 6 7
* row 0: [*] . . . . . . . <-- start at (0,0), offset=0
* row 1: . . . . . . . .
* row 2: . . . . . . . .
* row 3: . . . . . . . .
*
* Step = (2, 0): move two rows down.
* Reset step = (-1, 0): move one row up (e.g., undo traversal overshoot).
*
* Since ResetCoordinateAfterRun=false, MoveSliceWindow adds the
* reset step to the movement step before applying:
* adjusted_step = step + reset = (2,0) + (-1,0) = (1,0)
*
* Expected: (0,0) + (1,0) = (1,0), offset = 1*8 + 0 = 8
*/
using Helper = ThreadwiseTransferHelper_Base;
constexpr auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0));
EXPECT_EQ(coord.GetOffset(), 0);
const auto step_idx = make_multi_index(2, 0);
auto get_reset_step = []() { return make_multi_index(-1, 0); };
Helper::MoveSliceWindow<decltype(desc), decltype(coord), false>(
desc, coord, step_idx, get_reset_step);
// adjusted_step = (2,0) + (-1,0) = (1,0)
// Offset: 1*8 + 0 = 8
EXPECT_EQ(coord.GetOffset(), 8);
}
TEST(ThreadwiseTransferHelperBase, MoveSliceWindow_3D_ResetFused)
{
/*
* Scenario: 3D packed tensor (2 x 4 x 8), modelling a typical GEMM
* intermediate buffer with SliceLengths = (batch, row, col).
*
* Layout (batch=0 shown, row-major packed):
*
* batch 0:
* col: 0 1 2 3 4 5 6 7
* row 0: . . . . . . . .
* row 1: . . . . . . . .
* row 2: . . . . . . . .
* row 3: . . . . . . . .
*
* batch 1: (same structure, offset += 4*8 = 32)
*
* Start at (0, 0, 0), offset=0.
*
* Step = (0, 2, 0): move 2 rows down within the same batch.
* Reset step = (0, -1, 0): undo 1 row of traversal overshoot.
*
* ResetCoordinateAfterRun=false, so steps are fused:
* adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0)
*
* Expected: (0,0,0) + (0,1,0) = (0,1,0)
* Offset in packed layout: 0*(4*8) + 1*8 + 0 = 8
*/
using Helper = ThreadwiseTransferHelper_Base;
constexpr auto desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<2>{}, Number<4>{}, Number<8>{}));
auto coord = make_tensor_coordinate(desc, make_multi_index(0, 0, 0));
EXPECT_EQ(coord.GetOffset(), 0);
const auto step_idx = make_multi_index(0, 2, 0);
auto get_reset_step = []() { return make_multi_index(0, -1, 0); };
Helper::MoveSliceWindow<decltype(desc), decltype(coord), false>(
desc, coord, step_idx, get_reset_step);
// adjusted_step = (0,2,0) + (0,-1,0) = (0,1,0)
// Offset: 0*32 + 1*8 + 0 = 8
EXPECT_EQ(coord.GetOffset(), 8);
}
// =============================================================================
// ThreadwiseTransferHelper_Serpentine::ComputeForwardSweep tests
// =============================================================================
TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_EvenRow)
{
/*
* 2D serpentine traversal on a 4x4 grid:
*
* dim1 ->
* 0 1 2 3
* +-->-->-->--+ row 0: forward (dim0=0 is even)
* +--<--<--<--+ row 1: backward (dim0=1 is odd)
* +-->-->-->--+ row 2: forward (dim0=2 is even)
* +--<--<--<--+ row 3: backward (dim0=3 is odd)
* dim0
*
* At position (0, *): dim0 is even -> dim1 sweeps FORWARD
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<0>{}, Number<0>{});
constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{});
constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths);
EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward (outermost)
EXPECT_TRUE(sweep[Number<1>{}]); // dim 1: forward because dim0 position (0) is even
}
TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_2D_OddRow)
{
/*
* Same 4x4 grid, but at row 1:
*
* +-->-->-->--+ row 0
* +--<--<--<--+ row 1: dim0=1 is odd -> dim1 sweeps BACKWARD
*
* At position (1, *): dim0 is odd -> dim1 sweeps BACKWARD
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<1>{}, Number<0>{});
constexpr auto lengths = make_tuple(Number<4>{}, Number<4>{});
constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths);
EXPECT_TRUE(sweep[Number<0>{}]); // dim 0: always forward
EXPECT_FALSE(sweep[Number<1>{}]); // dim 1: backward (dim0 position 1 is odd)
}
TEST(ThreadwiseTransferHelperSerpentine, ComputeForwardSweep_1D)
{
/*
* 1D traversal: always forward regardless of position.
*
* 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<3>{});
constexpr auto lengths = make_tuple(Number<8>{});
constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths);
EXPECT_TRUE(sweep[Number<0>{}]); // 1D: only dimension, always forward
}
// =============================================================================
// ThreadwiseTransferHelper_Serpentine::ComputeMoveOnDim tests
// =============================================================================
TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerNotComplete)
{
/*
* 2D grid with ordered_access_lengths = (3, 2):
*
* dim1: 0 1
* dim0:
* 0 [*] . <-- at (0,0): dim1 hasn't reached end yet
* 1 . .
* 2 . .
*
* Rule: a dimension moves only when all faster-varying (higher-index)
* dimensions have completed their range.
*
* At (0, 0):
* dim0: dim1 is at 0, not at end (1). -> dim0 does NOT move.
* dim1: no higher dims to check, and 0 < 1. -> dim1 MOVES.
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<0>{}, Number<0>{});
constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{});
constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths);
EXPECT_FALSE(move[Number<0>{}]); // dim 0: inner dim NOT at end
EXPECT_TRUE(move[Number<1>{}]); // dim 1: can advance
}
TEST(ThreadwiseTransferHelperSerpentine, ComputeMoveOnDim_InnerComplete)
{
/*
* Same grid, at position (0, 1):
*
* dim1: 0 1
* dim0:
* 0 . [*] <-- at (0,1): dim1 at its end (1 == 2-1)
* 1 . .
* 2 . .
*
* At (0, 1):
* dim0: dim1 is at end (1 == 1). dim0 < 2. -> dim0 MOVES.
* dim1: at end. -> dim1 does NOT move.
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<0>{}, Number<1>{});
constexpr auto lengths = make_tuple(Number<3>{}, Number<2>{});
constexpr auto move = Helper::ComputeMoveOnDim(idx, lengths);
EXPECT_TRUE(move[Number<0>{}]); // dim 0: inner dim at end, can advance
EXPECT_FALSE(move[Number<1>{}]); // dim 1: at its limit, cannot advance
}
// =============================================================================
// ThreadwiseTransferHelper_Serpentine::ComputeDataIndex tests
// =============================================================================
TEST(ThreadwiseTransferHelperSerpentine, ComputeDataIndex_ForwardSweep)
{
/*
* 2D grid (4x3), both dims sweeping forward, identity order, scale=1:
*
* ordered_access_idx = (2, 1)
* forward_sweep = (true, true)
* dim_access_order = (0, 1) <-- identity
* scalar_per_access = (1, 1) <-- no scaling
*
* Forward: data_idx = ordered_idx = (2, 1)
* Reorder: identity -> (2, 1)
* Scale: * (1,1) -> (2, 1)
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto idx = make_tuple(Number<2>{}, Number<1>{});
constexpr auto lengths = make_tuple(Number<4>{}, Number<3>{});
constexpr auto sweep = Helper::ComputeForwardSweep(idx, lengths);
constexpr auto order = Sequence<0, 1>{};
constexpr auto spa = Sequence<1, 1>{};
constexpr auto data_idx = Helper::ComputeDataIndex(idx, lengths, sweep, order, spa);
EXPECT_EQ(data_idx[Number<0>{}], 2);
EXPECT_EQ(data_idx[Number<1>{}], 1);
}
// =============================================================================
// ThreadwiseTransferHelper_Serpentine::ComputeCoordinateResetStep tests
// =============================================================================
TEST(ThreadwiseTransferHelperSerpentine, ComputeCoordinateResetStep_2D)
{
/*
* SliceLengths = (4, 2), VectorDim = 1, ScalarPerVector = 2
* DimAccessOrder = (0, 1)
*
* scalar_per_access = (1, 2) [only dim 1 is vectorized with width 2]
* access_lengths = (4, 1) [4/1=4, 2/2=1]
*
* The traversal visits 4 positions along dim 0, each accessing 2 elements:
*
* dim0=0: access [0,0..1]
* dim0=1: access [1,0..1] (backward sweep, but only 1 step on dim1)
* dim0=2: access [2,0..1]
* dim0=3: access [3,0..1]
*
* Final position: data_idx = (3, 0) * scalar_per_access = (3, 0)
* Reset step: -(3, 0) = (-3, 0)
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
constexpr auto reset =
Helper::ComputeCoordinateResetStep<Sequence<4, 2>, 1, 2, Sequence<0, 1>>();
EXPECT_EQ(reset[Number<0>{}], -3);
EXPECT_EQ(reset[Number<1>{}], 0);
}
// =============================================================================
// VectorSizeLookupTable / VectorOffsetsLookupTable tests
// =============================================================================
TEST(ThreadwiseTransferHelperSerpentine, VectorSizeLookupTable)
{
/*
* Binary decomposition of vector widths into power-of-2 sub-loads:
*
* Width 0: (empty) -- no loads
* Width 1: {1} -- single 1-wide load
* Width 7: {4, 2, 1} -- 4+2+1 = 7
* Width 8: {8} -- single 8-wide load
* Width 16: {16} -- single 16-wide load
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
using VecSize0 = tuple_element_t<0, Helper::VectorSizeLookupTable>;
using VecSize1 = tuple_element_t<1, Helper::VectorSizeLookupTable>;
using VecSize7 = tuple_element_t<7, Helper::VectorSizeLookupTable>;
using VecSize8 = tuple_element_t<8, Helper::VectorSizeLookupTable>;
using VecSize16 = tuple_element_t<16, Helper::VectorSizeLookupTable>;
EXPECT_EQ(VecSize0::Size(), 0);
EXPECT_EQ(VecSize1::Size(), 1);
EXPECT_EQ(VecSize1::At(0), 1);
EXPECT_EQ(VecSize7::Size(), 3);
EXPECT_EQ(VecSize7::At(0), 4); // first sub-load: 4 elements
EXPECT_EQ(VecSize7::At(1), 2); // second sub-load: 2 elements
EXPECT_EQ(VecSize7::At(2), 1); // third sub-load: 1 element
EXPECT_EQ(VecSize8::Size(), 1);
EXPECT_EQ(VecSize8::At(0), 8);
EXPECT_EQ(VecSize16::Size(), 1);
EXPECT_EQ(VecSize16::At(0), 16);
}
TEST(ThreadwiseTransferHelperSerpentine, VectorOffsetsLookupTable)
{
/*
* Starting element offsets for each sub-load in the decomposition:
*
* Width 7 = {4, 2, 1}:
* |<--- 4 --->|<- 2 ->|1|
* offset 0 offset 4 offset 6
*
* So offsets = {0, 4, 6}
*/
using Helper = ThreadwiseTransferHelper_Serpentine;
using VecOff7 = tuple_element_t<7, Helper::VectorOffsetsLookupTable>;
EXPECT_EQ(VecOff7::Size(), 3);
EXPECT_EQ(VecOff7::At(0), 0); // first sub-load starts at offset 0
EXPECT_EQ(VecOff7::At(1), 4); // second sub-load starts at offset 4
EXPECT_EQ(VecOff7::At(2), 6); // third sub-load starts at offset 6
}
// =============================================================================
// ThreadwiseTransferHelper_SFC tests
// =============================================================================
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_SingleAccess)
{
/*
* SliceLengths = (1, 1), ScalarPerAccess = (1, 1)
* Only 1 access position total -> already at origin, reset = (0, 0)
*
* [*] <-- single element, no movement needed
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 1>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<1, 1>,
Sequence<0, 1>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], 0);
EXPECT_EQ(reset[Number<1>{}], 0);
}
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_RowMajor)
{
/*
* Typical v6r1 scenario: 2D slice transfer with vectorized column access.
*
* SliceLengths = (4, 8) -- 4 rows, 8 columns
* DimAccessOrder = (0, 1) -- row-major traversal (rows change slowest)
* ScalarPerAccess = (1, 4) -- 4-wide vector loads along columns
*
* access_lengths = SliceLengths / ScalarPerAccess = (4, 2)
*
* The SFC traverses in serpentine order through 4*2 = 8 access positions:
*
* col: 0..3 4..7
* row 0: [0]-->[1] access 0 -> idx (0,0), access 1 -> idx (0,4)
* row 1: [3]<--[2] access 2 -> idx (1,4), access 3 -> idx (1,0)
* row 2: [4]-->[5] access 4 -> idx (2,0), access 5 -> idx (2,4)
* row 3: [7]<--[6] access 6 -> idx (3,4), access 7 -> idx (3,0)
*
* Last access (#7) lands at index (3, 0).
* Reset step = origin - last = (0,0) - (3,0) = (-3, 0)
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 4>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
Sequence<0, 1>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], -3); // return 3 rows up
EXPECT_EQ(reset[Number<1>{}], 0); // column already at origin
}
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_2D_ColMajor)
{
/*
* Same 2D slice but column-major traversal order.
*
* SliceLengths = (4, 8) -- 4 rows, 8 columns
* DimAccessOrder = (1, 0) -- column-major (columns change slowest)
* ScalarPerAccess = (1, 4) -- 4-wide vector loads along columns
*
* access_lengths = (4, 2)
* ordered_access_lengths = reorder_new2old((4,2), (1,0)) = (2, 4)
* (dim 1 is the "slow" outer dimension, dim 0 is the "fast" inner)
*
* Traversal (ordered dims are [col_block, row]):
*
* col_block: 0 1
* row 0: [0] [7]
* row 1: [1] [6]
* row 2: [2] [5]
* row 3: [3] [4]
*
* Unordered indices (natural dim order):
* access 0 -> (row=0, col=0*4=0)
* access 3 -> (row=3, col=0)
* access 4 -> (row=3, col=1*4=4) (serpentine reversal in row)
* access 7 -> (row=0, col=4)
*
* Last access (#7) lands at index (0, 4).
* Reset step = (0,0) - (0,4) = (0, -4)
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 4>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 8>,
Sequence<1, 0>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], 0); // row already at origin
EXPECT_EQ(reset[Number<1>{}], -4); // return 4 columns left
}
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_3D)
{
/*
* 3D slice transfer, modelling a batch x row x col tile as used in
* batched GEMM or attention kernels (v7r2/v7r3).
*
* SliceLengths = (2, 4, 8) -- 2 batches, 4 rows, 8 columns
* DimAccessOrder = (0, 1, 2) -- batch outermost, column innermost
* ScalarPerAccess = (1, 1, 8) -- 8-wide vector loads on columns
*
* access_lengths = (2, 4, 1)
* Total accesses = 2 * 4 * 1 = 8
*
* Traversal within each batch is serpentine on rows, columns scalar:
*
* batch 0:
* row 0: [0] -- (0, 0, 0)
* row 1: [1] -- (0, 1, 0)
* row 2: [2] -- (0, 2, 0)
* row 3: [3] -- (0, 3, 0)
*
* batch 1: (serpentine reversal on rows)
* row 3: [4] -- (1, 3, 0)
* row 2: [5] -- (1, 2, 0)
* row 1: [6] -- (1, 1, 0)
* row 0: [7] -- (1, 0, 0)
*
* Last access (#7) lands at index (1, 0, 0).
* Reset step = (0,0,0) - (1,0,0) = (-1, 0, 0)
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 1, 8>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<2, 4, 8>,
Sequence<0, 1, 2>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], -1); // return 1 batch
EXPECT_EQ(reset[Number<1>{}], 0); // row already at origin (serpentine came back)
EXPECT_EQ(reset[Number<2>{}], 0); // column at origin (single access per row)
}
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_EvenInnerAccesses)
{
/*
* Inner dimension with an even number of accesses (2 per row). The column
* reset is 0 because the final row (odd index) of the serpentine sweeps backward.
*
* SliceLengths = (4, 4)
* DimAccessOrder = (0, 1)
* ScalarPerAccess = (1, 2) -- 2-wide vector loads
*
* access_lengths = (4, 2) -- 2 accesses along cols (even)
*
* col: 0..1 2..3
* row 0: [0]-->[1] access 0 -> (0,0), access 1 -> (0,2)
* row 1: [3]<--[2] access 2 -> (1,2), access 3 -> (1,0)
* row 2: [4]-->[5] access 4 -> (2,0), access 5 -> (2,2)
* row 3: [7]<--[6] access 6 -> (3,2), access 7 -> (3,0)
*
* Last access (#7) at (3, 0). The final row (index 3, odd) sweeps backward,
* so the traversal ends back at col=0.
* Reset step = (0,0) - (3,0) = (-3, 0)
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 2>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<4, 4>,
Sequence<0, 1>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], -3);
EXPECT_EQ(reset[Number<1>{}], 0); // final (backward) row ends at start column
}
TEST(ThreadwiseTransferHelperSFC, ComputeSFCCoordinateResetStep_OddInnerAccesses)
{
/*
* When the number of accesses along the inner dimension is odd and the
* outer dimension is even, the serpentine returns to col=0.
*
* SliceLengths = (2, 6)
* DimAccessOrder = (0, 1)
* ScalarPerAccess = (1, 2) -- 2-wide vector loads
*
* access_lengths = (2, 3) -- 3 accesses along cols (odd!)
*
* col: 0..1 2..3 4..5
* row 0: [0]-->[1]-->[2] access 0 -> (0,0), 1 -> (0,2), 2 -> (0,4)
* row 1: [5]<--[4]<--[3] access 3 -> (1,4), 4 -> (1,2), 5 -> (1,0)
*
* Last access (#5) at (1, 0). Even row count means serpentine reversal
* on the inner dim brings us back to col=0.
* Reset step = (0,0) - (1,0) = (-1, 0)
*/
using SFCHelper = ThreadwiseTransferHelper_SFC;
constexpr auto scalar_per_access = Sequence<1, 2>{};
constexpr auto reset = SFCHelper::ComputeSFCCoordinateResetStep<Sequence<2, 6>,
Sequence<0, 1>,
decltype(scalar_per_access)>();
EXPECT_EQ(reset[Number<0>{}], -1); // return 1 row
EXPECT_EQ(reset[Number<1>{}], 0); // even outer accesses -> serpentine came back to col=0
}
// =============================================================================
// Inheritance structure tests
// =============================================================================
TEST(ThreadwiseTransferHelperInheritance, SerpentineIsDerivedFromBase)
{
/*
* ThreadwiseTransferHelper_Base
* |
* +-- ThreadwiseTransferHelper_Serpentine <-- this relationship
* |
* +-- ThreadwiseTransferHelper_SFC
*/
static_assert(
std::is_base_of_v<ThreadwiseTransferHelper_Base, ThreadwiseTransferHelper_Serpentine>);
}
TEST(ThreadwiseTransferHelperInheritance, SFCIsDerivedFromBase)
{
/*
* ThreadwiseTransferHelper_Base
* |
* +-- ThreadwiseTransferHelper_Serpentine
* |
* +-- ThreadwiseTransferHelper_SFC <-- this relationship
*/
static_assert(std::is_base_of_v<ThreadwiseTransferHelper_Base, ThreadwiseTransferHelper_SFC>);
}
TEST(ThreadwiseTransferHelperInheritance, SerpentineAndSFCAreNotRelated)
{
/*
* Serpentine and SFC are siblings -- neither inherits from the other.
*
* ThreadwiseTransferHelper_Base
* |
* +-- Serpentine (NOT parent of SFC)
* |
* +-- SFC (NOT parent of Serpentine)
*/
static_assert(
!std::is_base_of_v<ThreadwiseTransferHelper_Serpentine, ThreadwiseTransferHelper_SFC>);
static_assert(
!std::is_base_of_v<ThreadwiseTransferHelper_SFC, ThreadwiseTransferHelper_Serpentine>);
}
// =============================================================================
// detail:: functor tests
// =============================================================================
TEST(DetailFunctors, LambdaScalarPerAccess)
{
/*
* For VectorDim=1 and ScalarPerVector=8:
*
* dim: 0 1 2
* result: 1 8 1
* ^ ^ ^
* | | +-- not the vector dim
* | +------ THE vector dim (returns ScalarPerVector)
* +---------- not the vector dim
*/
constexpr auto f = detail::lambda_scalar_per_access<1, 8>{};
EXPECT_EQ(f(0), 1);
EXPECT_EQ(f(1), 8);
EXPECT_EQ(f(2), 1);
}
TEST(DetailFunctors, LambdaScalarStepInVector)
{
/*
* For VectorDim=2:
*
* dim: 0 1 2 3
* result: 0 0 1 0
* ^
* +-- THE vector dim (step = 1)
*/
constexpr auto f = detail::lambda_scalar_step_in_vector<2>{};
EXPECT_EQ(f(0), 0);
EXPECT_EQ(f(1), 0);
EXPECT_EQ(f(2), 1);
EXPECT_EQ(f(3), 0);
}
TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_SameDim)
{
/*
* Src and Dst both vectorize dim 1:
* SrcVectorDim=1, SrcScalarPerVector=4
* DstVectorDim=1, DstScalarPerVector=8
*
* dim: 0 1 2
* result: 1 lcm(4,8) 1
* = 8
*/
constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<1, 4, 1, 8>{};
EXPECT_EQ(f(0), 1);
EXPECT_EQ(f(1), 8); // lcm(4, 8) = 8
EXPECT_EQ(f(2), 1);
}
TEST(DetailFunctors, LambdaScalarPerAccessForSrcAndDst_DifferentDims)
{
/*
* Src vectorizes dim 0 (width 4), Dst vectorizes dim 2 (width 8):
*
* dim: 0 1 2
* result: 4(src) 1 8(dst)
*/
constexpr auto f = detail::lambda_scalar_per_access_for_src_and_dst<0, 4, 2, 8>{};
EXPECT_EQ(f(0), 4); // src vector dim
EXPECT_EQ(f(1), 1); // neither
EXPECT_EQ(f(2), 8); // dst vector dim
}