[rocm-libraries] ROCm/rocm-libraries#5041 (commit 481aecc)

[CK] Precompute SpaceFillingCurve indices to reduce compile time by 31% (#5041)

## Summary

Optimize `SpaceFillingCurve` in CK to reduce compile time by
precomputing all index values into a static constexpr lookup table.

### Problem
- `GetIndex<N>` was instantiated separately for every index value (0 to
NumAccesses-1)
- Each instantiation triggered nested `static_for` loops with O(N²)
template depth
- This caused **34,000+ template instantiations** taking **69 seconds**
in frontend

### Solution
- Add `IndexLookupTable<NumAccesses, nDim>` to store all precomputed
indices
- Add `compute_single_index()` helper using O(N) `static_for` loops
- Add `compute_all_indices()` to build entire table in one constexpr
evaluation
- `GetIndex<N>` becomes simple array lookup: `return index_table[N]`

### Results (conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)

| Metric | Before | After | Improvement |
|--------|--------|-------|-------------|
| Total compile time | 120.4s | 83.6s | **-31%** |
| Frontend time | 88.7s | 52.6s | **-41%** |
| GetIndex instantiations | 34,176 | 384 | **-99%** |
| GetIndex time | 69.0s | 0.11s | **-99.8%** |
| SpaceFillingCurve time | 75.7s | 4.3s | **-94%** |

## Test plan
- [x] Builds successfully with `-Werror -Weverything`
- [ ] Run existing unit tests
- [ ] Verify numerical correctness on sample kernels

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Max Podkorytov
2026-03-10 19:41:40 +00:00
committed by assistant-librarian[bot]
parent 51537eb189
commit b8def2c724

View File

@@ -12,6 +12,89 @@
namespace ck {
namespace detail {
// Lookup table holding the precomputed N-D index for every 1D access value.
// Kept as a plain aggregate so it can be filled element-by-element inside a
// single constexpr evaluation and then stored as a static constexpr class
// member of SpaceFillingCurve.
template <index_t NumAccesses, index_t nDim>
struct IndexLookupTable
{
// Sized to at least 1 so the array declaration stays well-formed even when
// NumAccesses == 0 (zero-length arrays are ill-formed in standard C++).
MultiIndex<nDim> data[NumAccesses > 0 ? NumAccesses : 1];
// Read-only access to the precomputed index for 1D position i.
// Precondition: 0 <= i < NumAccesses — no bounds check is performed here.
__host__ __device__ constexpr const MultiIndex<nDim>& operator[](index_t i) const
{
return data[i];
}
};
// Convert a single 1D access position into its N-D coordinate along the
// (optionally snaked) space-filling curve, expressed in ordered-access space.
// Used only while building the lookup table. static_for is required because
// MultiIndex is indexed with compile-time Number<I> values.
//
// idx_1d          - position along the 1D curve, in [0, NumAccesses)
// strides         - reverse-exclusive-scan strides of the ordered access lengths
// ordered_lengths - access lengths permuted into dim-access order
template <index_t nDim, bool SnakeCurved, typename Strides, typename OrderedAccessLengths>
__host__ __device__ constexpr auto
compute_single_index(index_t idx_1d, Strides strides, OrderedAccessLengths ordered_lengths)
{
    // Decompose the 1D position into N-D ordered coordinates via the strides
    // (quotient per dimension, then strip that contribution from the rest).
    MultiIndex<nDim> coord;
    index_t rest = idx_1d;
    static_for<0, nDim, 1>{}([&](auto i) {
        const index_t q = rest / strides[i];
        coord(i)        = q;
        rest -= q * strides[i];
    });

    // Apply the snake-curve reflection in the same pass that accumulates the
    // sweep parity. Dimension 0 always sweeps forward; dimension i (i >= 1)
    // sweeps forward iff the flattened prefix of coordinates 0..i-1 is even.
    // The parity must be read BEFORE folding coordinate i into the prefix.
    MultiIndex<nDim> snaked;
    snaked(Number<0>{}) = coord[Number<0>{}];
    index_t prefix      = coord[Number<0>{}];
    static_for<1, nDim, 1>{}([&](auto i) {
        const bool forward = (prefix % 2 == 0);
        snaked(i) = (!SnakeCurved || forward) ? coord[i] : ordered_lengths[i] - 1 - coord[i];
        prefix    = prefix * ordered_lengths[i] + coord[i];
    });

    return snaked;
}
// Build the complete lookup table of N-D tensor indices, one entry per 1D
// access position, in a single constexpr evaluation. Doing this once at class
// instantiation replaces the former per-AccessIdx1d template instantiations.
//
// strides         - reverse-exclusive-scan strides of the ordered access lengths
// ordered_lengths - access lengths permuted into dim-access order
// dim_order       - old-to-new dimension permutation back to tensor order
// scalars         - per-dimension scalar counts used to scale each index
template <index_t NumAccesses,
          index_t nDim,
          bool SnakeCurved,
          typename Strides,
          typename OrderedAccessLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess>
__host__ __device__ constexpr auto compute_all_indices(Strides strides,
                                                       OrderedAccessLengths ordered_lengths,
                                                       DimAccessOrder dim_order,
                                                       ScalarsPerAccess scalars)
{
    IndexLookupTable<NumAccesses, nDim> table{};

    index_t i = 0;
    while(i < NumAccesses)
    {
        // Curve coordinate in ordered-access space for position i.
        const auto curve_idx =
            compute_single_index<nDim, SnakeCurved>(i, strides, ordered_lengths);
        // Permute back to the tensor's own dimension order and scale to scalars.
        table.data[i] = container_reorder_given_old2new(curve_idx, dim_order) * scalars;
        ++i;
    }

    return table;
}
} // namespace detail
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess,
@@ -30,6 +113,18 @@ struct SpaceFillingCurve
static constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
// Precompute access strides at class level
static constexpr auto access_strides =
container_reverse_exclusive_scan(ordered_access_lengths, math::multiplies{}, Number<1>{});
// Number of access indices
static constexpr index_t NumAccesses =
reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / ScalarPerVector;
// Precompute ALL indices into a lookup table - computed once at class instantiation
static constexpr auto index_table = detail::compute_all_indices<NumAccesses, nDim, SnakeCurved>(
access_strides, ordered_access_lengths, dim_access_order, ScalarsPerAccess{});
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(ordered_access_lengths)),
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
@@ -80,71 +175,9 @@ struct SpaceFillingCurve
// Return the multi-dimensional tensor index for 1D curve position AccessIdx1d.
//
// All indices were precomputed once into `index_table` when the class was
// instantiated (see detail::compute_all_indices), so this is a plain O(1)
// array lookup. The previous implementation recomputed the index here with
// nested static_for loops per AccessIdx1d, which dominated frontend compile
// time; that dead code (including the unreachable statements left after its
// `return`) is removed.
template <index_t AccessIdx1d>
static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
{
    static_assert(AccessIdx1d >= 0 && AccessIdx1d < NumAccesses, "Index out of bounds");

    // Simple lookup from precomputed table - O(1) with no template instantiation overhead
    return index_table[AccessIdx1d];
}
// FIXME: rename this function