mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#5041 (commit 481aecc)
[CK] Precompute SpaceFillingCurve indices to reduce compile time by 31% (#5041) ## Summary Optimize `SpaceFillingCurve` in CK to reduce compile time by precomputing all index values into a static constexpr lookup table. ### Problem - `GetIndex<N>` was instantiated separately for every index value (0 to NumAccesses-1) - Each instantiation triggered nested `static_for` loops with O(N²) template depth - This caused **34,000+ template instantiations** taking **69 seconds** in frontend ### Solution - Add `IndexLookupTable<NumAccesses, nDim>` to store all precomputed indices - Add `compute_single_index()` helper using O(N) `static_for` loops - Add `compute_all_indices()` to build entire table in one constexpr evaluation - `GetIndex<N>` becomes simple array lookup: `return index_table[N]` ### Results (conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp) | Metric | Before | After | Improvement | |--------|--------|-------|-------------| | Total compile time | 120.4s | 83.6s | **-31%** | | Frontend time | 88.7s | 52.6s | **-41%** | | GetIndex instantiations | 34,176 | 384 | **-99%** | | GetIndex time | 69.0s | 0.11s | **-99.8%** | | SpaceFillingCurve time | 75.7s | 4.3s | **-94%** | ## Test plan - [x] Builds successfully with `-Werror -Weverything` - [ ] Run existing unit tests - [ ] Verify numerical correctness on sample kernels 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
51537eb189
commit
b8def2c724
@@ -12,6 +12,89 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Lookup table to store precomputed indices for all 1D access values
|
||||
template <index_t NumAccesses, index_t nDim>
|
||||
struct IndexLookupTable
|
||||
{
|
||||
MultiIndex<nDim> data[NumAccesses > 0 ? NumAccesses : 1];
|
||||
|
||||
__host__ __device__ constexpr const MultiIndex<nDim>& operator[](index_t i) const
|
||||
{
|
||||
return data[i];
|
||||
}
|
||||
};
|
||||
|
||||
// Compute a single index given 1D access index - used internally during table construction
|
||||
// Uses static_for to work with MultiIndex which requires Number<I> for indexing
|
||||
template <index_t nDim, bool SnakeCurved, typename Strides, typename OrderedAccessLengths>
|
||||
__host__ __device__ constexpr auto
|
||||
compute_single_index(index_t idx_1d, Strides strides, OrderedAccessLengths ordered_lengths)
|
||||
{
|
||||
// Step 1: Convert 1D index to N-D ordered coordinates using strides
|
||||
MultiIndex<nDim> ordered_access_idx;
|
||||
index_t remaining = idx_1d;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_access_idx(i) = remaining / strides[i];
|
||||
remaining = remaining % strides[i];
|
||||
});
|
||||
|
||||
// Step 2: Compute forward_sweep - whether each dimension is in forward direction
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep;
|
||||
forward_sweep(Number<0>{}) = true;
|
||||
index_t cumulative = ordered_access_idx[Number<0>{}];
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
forward_sweep(i) = cumulative % 2 == 0;
|
||||
cumulative = cumulative * ordered_lengths[i] + ordered_access_idx[i];
|
||||
});
|
||||
|
||||
// Step 3: Apply snake curve transformation
|
||||
MultiIndex<nDim> ordered_idx;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(!SnakeCurved || forward_sweep[i])
|
||||
{
|
||||
ordered_idx(i) = ordered_access_idx[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
ordered_idx(i) = ordered_lengths[i] - 1 - ordered_access_idx[i];
|
||||
}
|
||||
});
|
||||
|
||||
return ordered_idx;
|
||||
}
|
||||
|
||||
// Precompute all indices into a lookup table using a single constexpr loop
|
||||
template <index_t NumAccesses,
|
||||
index_t nDim,
|
||||
bool SnakeCurved,
|
||||
typename Strides,
|
||||
typename OrderedAccessLengths,
|
||||
typename DimAccessOrder,
|
||||
typename ScalarsPerAccess>
|
||||
__host__ __device__ constexpr auto compute_all_indices(Strides strides,
|
||||
OrderedAccessLengths ordered_lengths,
|
||||
DimAccessOrder dim_order,
|
||||
ScalarsPerAccess scalars)
|
||||
{
|
||||
IndexLookupTable<NumAccesses, nDim> table{};
|
||||
|
||||
for(index_t idx_1d = 0; idx_1d < NumAccesses; ++idx_1d)
|
||||
{
|
||||
auto ordered_idx =
|
||||
compute_single_index<nDim, SnakeCurved>(idx_1d, strides, ordered_lengths);
|
||||
|
||||
// Reorder and scale
|
||||
auto reordered = container_reorder_given_old2new(ordered_idx, dim_order);
|
||||
table.data[idx_1d] = reordered * scalars;
|
||||
}
|
||||
|
||||
return table;
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename TensorLengths,
|
||||
typename DimAccessOrder,
|
||||
typename ScalarsPerAccess,
|
||||
@@ -30,6 +113,18 @@ struct SpaceFillingCurve
|
||||
static constexpr auto ordered_access_lengths =
|
||||
container_reorder_given_new2old(access_lengths, dim_access_order);
|
||||
|
||||
// Precompute access strides at class level
|
||||
static constexpr auto access_strides =
|
||||
container_reverse_exclusive_scan(ordered_access_lengths, math::multiplies{}, Number<1>{});
|
||||
|
||||
// Number of access indices
|
||||
static constexpr index_t NumAccesses =
|
||||
reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / ScalarPerVector;
|
||||
|
||||
// Precompute ALL indices into a lookup table - computed once at class instantiation
|
||||
static constexpr auto index_table = detail::compute_all_indices<NumAccesses, nDim, SnakeCurved>(
|
||||
access_strides, ordered_access_lengths, dim_access_order, ScalarsPerAccess{});
|
||||
|
||||
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(ordered_access_lengths)),
|
||||
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
|
||||
@@ -80,71 +175,9 @@ struct SpaceFillingCurve
|
||||
template <index_t AccessIdx1d>
|
||||
static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
|
||||
{
|
||||
#if 0
|
||||
/*
|
||||
* \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
|
||||
*/
|
||||
constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
|
||||
#else
|
||||
|
||||
constexpr auto access_strides = container_reverse_exclusive_scan(
|
||||
ordered_access_lengths, math::multiplies{}, Number<1>{});
|
||||
|
||||
constexpr auto idx_1d = Number<AccessIdx1d>{};
|
||||
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
|
||||
// idim-th element of multidimensional index.
|
||||
// All constexpr variables have to be captured by VALUE.
|
||||
constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
|
||||
constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
|
||||
auto res = idx_1d.value;
|
||||
auto id = 0;
|
||||
|
||||
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
|
||||
id = res / access_strides[kdim].value;
|
||||
res -= id * access_strides[kdim].value;
|
||||
});
|
||||
|
||||
return id;
|
||||
};
|
||||
|
||||
constexpr auto id = compute_index_impl(idim);
|
||||
return Number<id>{};
|
||||
};
|
||||
|
||||
constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
|
||||
#endif
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto idim) {
|
||||
index_t tmp = ordered_access_idx[I0];
|
||||
|
||||
static_for<1, idim, 1>{}(
|
||||
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
|
||||
|
||||
forward_sweep_(idim) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate multi-dim tensor index
|
||||
auto idx_md = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto idim) {
|
||||
ordered_idx(idim) =
|
||||
!SnakeCurved || forward_sweep[idim]
|
||||
? ordered_access_idx[idim]
|
||||
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
|
||||
});
|
||||
|
||||
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
|
||||
ScalarsPerAccess{};
|
||||
}();
|
||||
return idx_md;
|
||||
static_assert(AccessIdx1d >= 0 && AccessIdx1d < NumAccesses, "Index out of bounds");
|
||||
// Simple lookup from precomputed table - O(1) with no template instantiation overhead
|
||||
return index_table[AccessIdx1d];
|
||||
}
|
||||
|
||||
// FIXME: rename this function
|
||||
|
||||
Reference in New Issue
Block a user