mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-14 20:27:42 +00:00
[CK] Optimize multi-dimensional static for loop decomposition (#4447)

## Motivation

Recursive template implementations might initially seem attractive because they minimize the amount of code to write. Unfortunately, this style often hurts readability and requires significant compiler resources to generate the instantiation chains. In "high-traffic" code (e.g., code used in many places and compilation units), this generally does not scale well and can inflate overall compile times unnecessarily.

The aim of this PR is to take some of the most high-traffic utility code and, wherever possible, eliminate recursive templates in favor of fold expansions and constexpr function helpers.

In local tests with Clang Build Analyzer, device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_16x16_instance.cpp showed high hit rates on slow template instantiations in static_for and the multi-dimensional static_for (static_ford), whose cost in turn depends on the implementation of the Sequence class and its associated transforms. Example (from the analyzer's `**** Templates that took longest to instantiate:` section):

`70111 ms: ck::detail::applier<int, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1... (372 times, avg 188 ms)` // **70 seconds!**

The above is part of the implementation of static_for, which uses Sequence classes.
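The contrast between the two styles can be sketched with a small standalone example. It assumes only the C++17 standard library (it does not use CK's `Sequence`, `static_for`, or `applier`), and the names `product_recursive` and `product_fold` are hypothetical. The recursive form needs one class template specialization per pack element, whereas the fold expression expands the whole pack inside a single function template specialization.

```cpp
#include <utility>

// Hypothetical illustration, not CK code.

// Recursive style: product_recursive<2, 3, 4> instantiates
// product_recursive<3, 4>, <4> and <> in turn, so the instantiation
// chain grows linearly with the pack length.
template <int... Is>
struct product_recursive
{
    static constexpr int value = 1; // base case: empty pack
};

template <int I, int... Is>
struct product_recursive<I, Is...>
{
    static constexpr int value = I * product_recursive<Is...>::value;
};

// Fold-expression style: the pack is expanded in place inside one
// function template specialization, with no intermediate types.
template <int... Is>
constexpr int product_fold(std::integer_sequence<int, Is...>)
{
    return (1 * ... * Is); // binary left fold; yields 1 for an empty pack
}
```

Both `product_recursive<2, 3, 4>::value` and `product_fold(std::integer_sequence<int, 2, 3, 4>{})` evaluate to 24; only the fold version avoids the chain of intermediate specializations.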
## Technical Details

### Summary of Optimization Techniques

| Technique | Used In | Benefit |
|-----------|---------|---------|
| __Constexpr for-loop computation__ | sequence_reverse_inclusive_scan, sequence_map_inverse | Moves O(N) work from template instantiation to constexpr evaluation |
| __Pack expansion with indexing__ | sequence_reverse, Sequence::Modify | Single template instantiation instead of a recursive chain |
| __Flat iteration + decomposition__ | ford, static_ford | O(1) template depth instead of O(N^D) |
| __Pre-computed strides__ | index_decomposer | Enables O(1) linear-to-multi-index conversion |

### Impact on Compile Time

These optimizations reduce template instantiation depth from O(N) or O(N^D) to O(1), which:

1. Reduces compiler memory usage
2. Reduces compile time, especially for deep instantiation chains
3. Enables larger iteration spaces without hitting template depth limits

## Test Plan

* Existing tests for Sequence are reused to confirm correctness
* Unit tests for ford and static_ford are added (multi-dimensional looping)
* 8 new regression tests specifically verify the fixes made in response to PR feedback:
  - `NonTrivialOrder3D_201` - Tests Orders<2,0,1> for static_ford
  - `NonTrivialOrder3D_201_Runtime` - Tests Orders<2,0,1> for ford
  - `ConsistencyWithNonTrivialOrder_201` - Verifies static_ford and ford consistency
  - `NonTrivialOrder3D_120` - Tests Orders<1,2,0> for static_ford
  - `NonTrivialOrder3D_120_Runtime` - Tests Orders<1,2,0> for ford
  - `NonTrivialOrder4D` - Tests 4D with Orders<3,1,0,2> for static_ford
  - `NonTrivialOrder4D_Runtime` - Tests 4D with Orders<3,1,0,2> for ford
  - `AsymmetricDimensionsWithOrder` - Tests asymmetric dimensions with non-trivial ordering

## Test Result

### Compile Time Comparison: `8b72bc8` (base) → `477e0686` (optimized)

#### Commits in Range (8 commits)

1. `fd4ca17f48` - Optimize sequence_reverse_inclusive_scan and sequence_reverse
2. `7a7e3fdeef` - Optimize sequence_map_inverse
3. `92855c9913` - Optimize ford and static_ford calls to eliminate nested template recursion
4. `88a564032b` - Add unit tests for ford and static_ford
5. `1a0fb22217` - Fix clang-format
6. `8a0d26bddf` - Increase template recursion depth to 1024
7. `dc53bb6e20` - Address copilot feedback and add regression tests
8. `477e06861d` - Increase bracket depth to 1024

#### Build Timing Results

| File | Base (8b72bc8759d9) | HEAD (a0438bd398) | Improvement |
|------|------|------|-------------|
| grouped_conv2d_fwd (f16) -j1 | 313.31s | 272.93s | __12.9% faster__ |
| grouped_conv1d_fwd (bf16) -j1 | 79.33s | 68.61s | __13.5% faster__ |
| grouped_conv1d_bwd_weight (f16) -j1 | 15.77s | 14.31s | __9.2% faster__ |
| device_grouped_conv2d_fwd_instance -j64 | s | s | __% faster__ |

#### Key Optimizations

1. __sequence_reverse_inclusive_scan/sequence_reverse__: O(N) → O(1) template depth
2. __sequence_map_inverse__: O(N) → O(1) template depth
3. __ford/static_ford__: O(N^D) → O(1) template depth using flat iteration with index decomposition
4. __Copilot feedback fixes__: Corrected New2Old mapping for non-trivial orderings

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
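The first two rows of the technique table (constexpr for-loop computation and pack expansion with indexing) can be illustrated with a small standalone sketch. It is not the code from this PR: it assumes `std::integer_sequence<int, ...>` in place of `ck::Sequence`, hard-wires the reverse inclusive scan to addition for brevity, and uses the hypothetical names `seq_reverse` and `seq_reverse_inclusive_scan`. In both helpers the per-element work happens either in a constexpr loop or in a single pack expansion over an index sequence, so each input sequence costs a fixed number of specializations rather than a chain whose length grows with N.

```cpp
#include <array>
#include <cstddef>
#include <type_traits>
#include <utility>

// Hypothetical helpers (not CK API), operating on std::integer_sequence.

// Constexpr for-loop computation: the O(N) running sum is an ordinary loop
// evaluated as a constant expression, not a chain of template instantiations.
// Result: v[i] = Is[i] + Is[i+1] + ... + Is[N-1]
template <int... Is>
constexpr std::array<int, sizeof...(Is)> reverse_inclusive_sums()
{
    std::array<int, sizeof...(Is)> v{Is...};
    for(std::size_t i = v.size(); i > 1; --i)
        v[i - 2] += v[i - 1];
    return v;
}

template <int... Is, std::size_t... Ks>
constexpr auto seq_reverse_inclusive_scan_impl(std::integer_sequence<int, Is...>,
                                               std::index_sequence<Ks...>)
{
    // Compute once, then expand the array back into a sequence type.
    constexpr auto v = reverse_inclusive_sums<Is...>();
    return std::integer_sequence<int, v[Ks]...>{};
}

template <int... Is>
constexpr auto seq_reverse_inclusive_scan(std::integer_sequence<int, Is...> s)
{
    return seq_reverse_inclusive_scan_impl(s, std::make_index_sequence<sizeof...(Is)>{});
}

// Pack expansion with indexing: element k of the result is element N-1-k
// of the input, produced by one expansion over the index pack Ks.
template <int... Is, std::size_t... Ks>
constexpr auto seq_reverse_impl(std::integer_sequence<int, Is...>, std::index_sequence<Ks...>)
{
    constexpr std::array<int, sizeof...(Is)> v{Is...};
    return std::integer_sequence<int, v[sizeof...(Is) - 1 - Ks]...>{};
}

template <int... Is>
constexpr auto seq_reverse(std::integer_sequence<int, Is...> s)
{
    return seq_reverse_impl(s, std::make_index_sequence<sizeof...(Is)>{});
}
```

For instance, `seq_reverse_inclusive_scan(std::integer_sequence<int, 1, 2, 3>{})` has type `std::integer_sequence<int, 6, 5, 3>`, and `seq_reverse` of the same input has type `std::integer_sequence<int, 3, 2, 1>`.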
243 lines
8.5 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/functional.hpp"
#include "ck/utility/functional2.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/multi_index.hpp"
#include "ck/utility/math.hpp"

namespace ck {

namespace detail {

/**
 * @brief Common base class for static_ford and ford.
 *
 * Provides shared compile-time constants and type aliases used by both
 * static_ford (compile-time iteration) and ford (runtime iteration).
 *
 * @tparam Lengths Sequence<L0, L1, ...> specifying the size of each dimension
 * @tparam Orders  Sequence<O0, O1, ...> specifying the iteration order of dimensions.
 *                 Orders[i] indicates which dimension is iterated at loop level i;
 *                 level 0 is the outermost (slowest-varying) loop.
 */
template <class Lengths, class Orders>
struct ford_base
{
    /// Number of dimensions
    static constexpr index_t NDim = Lengths::Size();

    /// Total number of iterations (product of all lengths)
    static constexpr index_t TotalSize =
        reduce_on_sequence(Lengths{}, math::multiplies{}, Number<1>{});

    /// Lengths reordered according to iteration order
    static constexpr auto OrderedLengths = Lengths::ReorderGivenNew2Old(Orders{});

    /// Type of OrderedLengths with cv- and reference qualifiers removed
    using OrderedLengthsType = remove_cvref_t<decltype(OrderedLengths)>;

    /// Mapping from loop level ("new" index) to original dimension ("old" index)
    using New2Old = Orders;

    __host__ __device__ constexpr ford_base()
    {
        static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size");
    }
};

/**
 * @brief Helper for decomposing a linear index into multi-dimensional indices.
 *
 * Computes strides at compile time and provides both compile-time and runtime
 * index decomposition. Used by static_ford and ford to convert a flat iteration
 * index into N-dimensional coordinates.
 *
 * For OrderedLengths = Sequence<L0, L1, L2>:
 *  - strides = {L1*L2, L2, 1}
 *  - ordered_idx[i] = (linear_idx / strides[i]) % lengths[i]
 *
 * For example, Sequence<2, 3, 4> gives strides = {12, 4, 1}, and linear_idx = 17
 * decomposes to ordered_idx = {1, 1, 1} (1*12 + 1*4 + 1*1 = 17).
 *
 * @tparam OrderedLengths Sequence<...> of dimension sizes in iteration order
 * @tparam IndexSeq Sequence<0, 1, ..., NDim-1> for pack expansion
 */
template <class OrderedLengths, class IndexSeq>
struct index_decomposer;

template <index_t... Ls, index_t... Is>
struct index_decomposer<Sequence<Ls...>, Sequence<Is...>>
{
    /// Number of dimensions
    static constexpr index_t NDim = sizeof...(Ls);

    /// Dimension lengths in iteration order
    static constexpr index_array<NDim> lengths = {{Ls...}};

    /**
     * @brief Compute all strides in a single O(N) pass.
     *
     * For dimensions with lengths [L0, L1, L2, ...]:
     *   strides[N-1] = 1
     *   strides[i]   = strides[i+1] * lengths[i+1]
     *
     * @return index_array containing computed strides
     */
    static constexpr index_array<NDim> compute_all_strides()
    {
        index_array<NDim> result{};
        if constexpr(NDim > 0)
        {
            result[NDim - 1] = 1;
            for(index_t i = NDim - 2; i >= 0; --i)
            {
                result[i] = result[i + 1] * lengths[i + 1];
            }
        }
        return result;
    }

    /// Pre-computed strides for each dimension
    static constexpr index_array<NDim> strides = compute_all_strides();

    /**
     * @brief Compile-time decomposition of a linear index.
     *
     * Names a Sequence containing the multi-dimensional indices
     * in iteration order.
     *
     * @tparam LinearIdx The linear index to decompose (compile-time constant)
     */
    template <index_t LinearIdx>
    using decompose = Sequence<((LinearIdx / strides[Is]) % lengths[Is])...>;

    /**
     * @brief Runtime decomposition of a linear index with reordering.
     *
     * Decomposes linear_idx into ordered indices, then reorders them
     * to the original dimension order and stores them in result.
     *
     * @tparam New2Old Sequence mapping iteration position to original dimension
     * @tparam MultiIndex Type of the output multi-index container
     * @param linear_idx The linear index to decompose
     * @param[out] result Multi-index container to store the result
     */
    template <class New2Old, class MultiIndex>
    __host__ __device__ static void decompose_runtime(index_t linear_idx, MultiIndex& result)
    {
        // Compute ordered indices and assign to result in original dimension order
        ((result(Number<New2Old::At(Number<Is>{})>{}) = (linear_idx / strides[Is]) % lengths[Is]),
         ...);
    }
};

} // namespace detail

/**
 * @brief Compile-time N-dimensional loop with static multi-indices.
 *
 * Iterates over an N-dimensional space where dimensions have sizes specified
 * by Lengths. The iteration order is controlled by Orders. Each iteration
 * provides a compile-time Sequence containing the current multi-index.
 *
 * Uses O(1) template instantiation depth via a flat loop with index decomposition,
 * avoiding recursive template structures.
 *
 * Example:
 * @code
 * // Iterate over 2x3 space in row-major order (dim 0 outer, dim 1 inner)
 * static_ford<Sequence<2, 3>>{}([](auto multi_id) {
 *     constexpr index_t i = multi_id[Number<0>{}]; // 0, 0, 0, 1, 1, 1
 *     constexpr index_t j = multi_id[Number<1>{}]; // 0, 1, 2, 0, 1, 2
 * });
 *
 * // Column-major order (dim 1 outer, dim 0 inner)
 * static_ford<Sequence<2, 3>, Sequence<1, 0>>{}([](auto multi_id) {
 *     // Visits: (0,0), (1,0), (0,1), (1,1), (0,2), (1,2)
 * });
 * @endcode
 *
 * @tparam Lengths Sequence<L0, L1, ...> specifying dimension sizes
 * @tparam Orders  Sequence<O0, O1, ...> specifying iteration order
 *                 (default: Sequence<0, 1, ..., N-1> for row-major)
 */
template <class Lengths, class Orders = make_index_sequence<Lengths::GetSize()>>
struct static_ford : detail::ford_base<Lengths, Orders>
{
    using Base = detail::ford_base<Lengths, Orders>;
    using Decomposer = detail::index_decomposer<typename Base::OrderedLengthsType,
                                                make_index_sequence<Base::NDim>>;

    /**
     * @brief Execute the N-dimensional loop.
     *
     * Calls f with a compile-time Sequence<i0, i1, ...> for each point
     * in the iteration space.
     *
     * @tparam F Functor type, invocable as f(Sequence<...>)
     * @param f The functor to call for each multi-index
     */
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        static_for<0, Base::TotalSize, 1>{}([&](auto linear_idx) {
            using OrderedIdx = typename Decomposer::template decompose<linear_idx.value>;
            f(OrderedIdx::ReorderGivenOld2New(Orders{}));
        });
    }
};

/**
 * @brief Runtime N-dimensional loop with runtime multi-indices.
 *
 * Iterates over an N-dimensional space where dimensions have sizes specified
 * by Lengths. The iteration order is controlled by Orders. Each iteration
 * provides a runtime multi-index container.
 *
 * Uses O(1) template instantiation depth via a flat for-loop with index decomposition,
 * avoiding recursive template structures.
 *
 * Example:
 * @code
 * // Iterate over 2x3 space in row-major order
 * ford<Sequence<2, 3>>{}([](auto multi_id) {
 *     index_t i = multi_id[Number<0>{}]; // Runtime values
 *     index_t j = multi_id[Number<1>{}];
 * });
 * @endcode
 *
 * @tparam Lengths Sequence<L0, L1, ...> specifying dimension sizes
 * @tparam Orders  Sequence<O0, O1, ...> specifying iteration order
 *                 (default: Sequence<0, 1, ..., N-1> for row-major)
 */
template <class Lengths, class Orders = make_index_sequence<Lengths::GetSize()>>
struct ford : detail::ford_base<Lengths, Orders>
{
    using Base = detail::ford_base<Lengths, Orders>;
    using Decomposer = detail::index_decomposer<typename Base::OrderedLengthsType,
                                                make_index_sequence<Base::NDim>>;

    /**
     * @brief Execute the N-dimensional loop.
     *
     * Calls f with a runtime multi-index for each point in the iteration space.
     *
     * @tparam F Functor type, invocable as f(MultiIndex)
     * @param f The functor to call for each multi-index
     */
    template <class F>
    __host__ __device__ constexpr void operator()(F f) const
    {
        for(index_t linear_idx = 0; linear_idx < Base::TotalSize; ++linear_idx)
        {
            auto multi_id = make_zero_multi_index<Base::NDim>();
            Decomposer::template decompose_runtime<Orders>(linear_idx, multi_id);
            f(multi_id);
        }
    }
};

} // namespace ck
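To see the flat-iteration scheme of `ford` and `index_decomposer::decompose_runtime` outside of CK, here is a minimal host-only sketch. It assumes plain `std::array` lengths and orders instead of `ck::Sequence`, returns the multi-index as a `std::array`, and uses the hypothetical names `flat_ford` and `visit_order` (not part of CK). It follows the same conventions as the header above: `orders[k]` names the dimension iterated at loop level k (level 0 outermost), strides are row-major over the reordered lengths, and the coordinate for level k is written back to dimension `orders[k]`.

```cpp
#include <array>
#include <cstddef>

// Hypothetical standalone analogue of ck::ford (not CK API).
// lengths[d]: size of dimension d; orders[k]: dimension iterated at loop level k
// (level 0 is the outermost loop). f receives the multi-index in original
// dimension order.
template <std::size_t N, class F>
constexpr void flat_ford(const std::array<int, N>& lengths,
                         const std::array<int, N>& orders,
                         F f)
{
    std::array<int, N> ordered_lengths{};
    std::array<int, N> strides{};
    int total = 1;
    for(std::size_t k = 0; k < N; ++k)
    {
        ordered_lengths[k] = lengths[static_cast<std::size_t>(orders[k])];
        total *= ordered_lengths[k];
    }
    // Row-major strides over the reordered lengths: strides[N-1] = 1.
    for(std::size_t k = N; k > 0; --k)
        strides[k - 1] = (k < N) ? strides[k] * ordered_lengths[k] : 1;

    // One flat loop replaces N nested loops.
    for(int linear = 0; linear < total; ++linear)
    {
        std::array<int, N> idx{};
        for(std::size_t k = 0; k < N; ++k)
        {
            // The coordinate at loop level k goes back to dimension orders[k].
            idx[static_cast<std::size_t>(orders[k])] = (linear / strides[k]) % ordered_lengths[k];
        }
        f(idx);
    }
}

// Usage helper: records each visited point as its row-major offset in the
// original layout, so the visiting order can be checked at compile time.
template <std::size_t N, std::size_t Total>
constexpr std::array<int, Total> visit_order(const std::array<int, N>& lengths,
                                             const std::array<int, N>& orders)
{
    std::array<int, Total> out{};
    std::size_t pos = 0;
    flat_ford(lengths, orders, [&](const std::array<int, N>& idx) {
        int offset = 0;
        for(std::size_t d = 0; d < N; ++d)
            offset = offset * lengths[d] + idx[d];
        out[pos++] = offset;
    });
    return out;
}
```

For lengths {2, 3} and orders {1, 0}, `visit_order<2, 6>` yields the offsets 0, 3, 1, 4, 2, 5, which correspond to the visits (0,0), (1,0), (0,1), (1,1), (0,2), (1,2) listed in the static_ford documentation above.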