From 60231a71b49b08e270d978b0756bd9c43ccc2293 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 10 Mar 2026 12:40:08 -0700 Subject: [PATCH] [CK] Precompute SpaceFillingCurve indices to reduce compile time by 31% (#5041) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Optimize `SpaceFillingCurve` in CK to reduce compile time by precomputing all index values into a static constexpr lookup table. ### Problem - `GetIndex` was instantiated separately for every index value (0 to NumAccesses-1) - Each instantiation triggered nested `static_for` loops with O(N²) template depth - This caused **34,000+ template instantiations** taking **69 seconds** in frontend ### Solution - Add `IndexLookupTable` to store all precomputed indices - Add `compute_single_index()` helper using O(N) `static_for` loops - Add `compute_all_indices()` to build entire table in one constexpr evaluation - `GetIndex` becomes simple array lookup: `return index_table[N]` ### Results (conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) | Metric | Before | After | Improvement | |--------|--------|-------|-------------| | Total compile time | 120.4s | 83.6s | **-31%** | | Frontend time | 88.7s | 52.6s | **-41%** | | GetIndex instantiations | 34,176 | 384 | **-99%** | | GetIndex time | 69.0s | 0.11s | **-99.8%** | | SpaceFillingCurve time | 75.7s | 4.3s | **-94%** | ## Test plan - [x] Builds successfully with `-Werror -Weverything` - [ ] Run existing unit tests - [ ] Verify numerical correctness on sample kernels 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> --- .../tensor_space_filling_curve.hpp | 163 +++++++++++------- 1 file changed, 98 insertions(+), 65 deletions(-) diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index a3b44bdf0b..c399c874b8 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -12,6 +12,89 @@ namespace ck { +namespace detail { + +// Lookup table to store precomputed indices for all 1D access values +template +struct IndexLookupTable +{ + MultiIndex data[NumAccesses > 0 ? NumAccesses : 1]; + + __host__ __device__ constexpr const MultiIndex& operator[](index_t i) const + { + return data[i]; + } +}; + +// Compute a single index given 1D access index - used internally during table construction +// Uses static_for to work with MultiIndex which requires Number for indexing +template +__host__ __device__ constexpr auto +compute_single_index(index_t idx_1d, Strides strides, OrderedAccessLengths ordered_lengths) +{ + // Step 1: Convert 1D index to N-D ordered coordinates using strides + MultiIndex ordered_access_idx; + index_t remaining = idx_1d; + static_for<0, nDim, 1>{}([&](auto i) { + ordered_access_idx(i) = remaining / strides[i]; + remaining = remaining % strides[i]; + }); + + // Step 2: Compute forward_sweep - whether each dimension is in forward direction + StaticallyIndexedArray forward_sweep; + forward_sweep(Number<0>{}) = true; + index_t cumulative = ordered_access_idx[Number<0>{}]; + static_for<1, nDim, 1>{}([&](auto i) { + forward_sweep(i) = cumulative % 2 == 0; + cumulative = cumulative * ordered_lengths[i] + ordered_access_idx[i]; + }); + + // Step 3: Apply snake curve transformation + MultiIndex ordered_idx; + static_for<0, nDim, 1>{}([&](auto i) { + if(!SnakeCurved || forward_sweep[i]) + { + ordered_idx(i) = ordered_access_idx[i]; + } + else + { + ordered_idx(i) = ordered_lengths[i] - 1 - ordered_access_idx[i]; + } + }); + + return ordered_idx; +} + +// Precompute all indices into a lookup table using a single constexpr loop +template +__host__ __device__ constexpr auto compute_all_indices(Strides strides, + OrderedAccessLengths ordered_lengths, + DimAccessOrder dim_order, + ScalarsPerAccess scalars) +{ + IndexLookupTable table{}; + + for(index_t idx_1d = 0; idx_1d < NumAccesses; ++idx_1d) + { + auto ordered_idx = + compute_single_index(idx_1d, strides, ordered_lengths); + + // Reorder and scale + auto reordered = container_reorder_given_old2new(ordered_idx, dim_order); + table.data[idx_1d] = reordered * scalars; + } + + return table; +} + +} // namespace detail + template {}); + + // Number of access indices + static constexpr index_t NumAccesses = + reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / ScalarPerVector; + + // Precompute ALL indices into a lookup table - computed once at class instantiation + static constexpr auto index_table = detail::compute_all_indices( + access_strides, ordered_access_lengths, dim_access_order, ScalarsPerAccess{}); + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(ordered_access_lengths)), make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), @@ -80,71 +175,9 @@ struct SpaceFillingCurve template static __device__ __host__ constexpr Index GetIndex(Number) { -#if 0 - /* - * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected. - */ - constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number{})); -#else - - constexpr auto access_strides = container_reverse_exclusive_scan( - ordered_access_lengths, math::multiplies{}, Number<1>{}); - - constexpr auto idx_1d = Number{}; - // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the - // idim-th element of multidimensional index. - // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { - constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { - auto res = idx_1d.value; - auto id = 0; - - static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { - id = res / access_strides[kdim].value; - res -= id * access_strides[kdim].value; - }); - - return id; - }; - - constexpr auto id = compute_index_impl(idim); - return Number{}; - }; - - constexpr auto ordered_access_idx = generate_tuple(compute_index, Number{}); -#endif - constexpr auto forward_sweep = [&]() { - StaticallyIndexedArray forward_sweep_; - - forward_sweep_(I0) = true; - - static_for<1, nDim, 1>{}([&](auto idim) { - index_t tmp = ordered_access_idx[I0]; - - static_for<1, idim, 1>{}( - [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); - - forward_sweep_(idim) = tmp % 2 == 0; - }); - - return forward_sweep_; - }(); - - // calculate multi-dim tensor index - auto idx_md = [&]() { - Index ordered_idx; - - static_for<0, nDim, 1>{}([&](auto idim) { - ordered_idx(idim) = - !SnakeCurved || forward_sweep[idim] - ? ordered_access_idx[idim] - : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; - }); - - return container_reorder_given_old2new(ordered_idx, dim_access_order) * - ScalarsPerAccess{}; - }(); - return idx_md; + static_assert(AccessIdx1d >= 0 && AccessIdx1d < NumAccesses, "Index out of bounds"); + // Simple lookup from precomputed table - O(1) with no template instantiation overhead + return index_table[AccessIdx1d]; } // FIXME: rename this function