[rocm-libraries] ROCm/rocm-libraries#4355 (commit e7f6909)

[CK TILE] Refactor sequence_reverse_inclusive_scan

## Proposed changes

Refactor the CK Tile `sequence_reverse_inclusive_scan` from a recursive
template implementation to a `constexpr` for-loop.

Tracking issue: #4229

This pull request introduces a new lightweight array type,
`static_array`, and refactors the sequence utilities to use it for
improved constexpr support and simplicity. The changes also include
updates to the build system to add container-related tests.

**Core Library Improvements:**

* Added a new header `static_array.hpp` that defines the `static_array`
type, a constexpr-friendly array with basic accessors and no custom
constructors.
* Updated includes in `core.hpp` and `sequence.hpp` to import
`static_array`.
[[1]](diffhunk://#diff-14b406eccf59794051a16c0c9c1a7e11234324bfdd107a5bbe0f173cd25bcddcR44)
[[2]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76R7)

**Refactoring to Use `static_array`:**

* Refactored sequence utilities in `sequence.hpp` to use `static_array`
instead of the previously forward-declared `array` type, including in
histogram and array generation logic.
[[1]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76L1108-R1133)
[[2]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76L1130-R1146)
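* Rewrote the implementation of `sequence_reverse_inclusive_scan` to use
`static_array` for intermediate storage, improving constexpr evaluation
and clarity. The same implementation also backs a forward
`sequence_inclusive_scan`, which `inclusive_scan_sequence` now uses
directly instead of reversing the sequence twice.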

**Build System and Testing:**

* Added a new test subdirectory for container tests and a GoogleTest
executable for `unit_sequence.cpp` to the CMake build configuration.
[[1]](diffhunk://#diff-5d35ff7555d3f0b438d45cde06b661eb1332cdbec66287ac7ec3c478d688aae5R5)
[[2]](diffhunk://#diff-1f54f0d2b431b7fc74f7b4ffb66e80c381c904c3383b1d27987467e3482d6d7aR1-R7)

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Cong Ma
2026-02-24 15:52:33 +00:00
committed by assistant-librarian[bot]
parent fc3180120e
commit 3af1a0aafc
6 changed files with 772 additions and 18 deletions

View File

@@ -39,6 +39,7 @@
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/static_array.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/static_array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
@@ -35,6 +36,7 @@ template <typename Seq>
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq);
namespace impl {
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
template <index_t I, typename... Ts>
using at_index_t = __type_pack_element<I, Ts...>;
@@ -331,30 +333,66 @@ struct uniform_sequence_gen
using type = typename sequence_gen<NSize, F>::type;
};
// reverse inclusive scan (with init) sequence
template <typename, typename, index_t>
struct sequence_reverse_inclusive_scan;
// inclusive scan (with init) sequence
namespace impl {
template <index_t I, index_t... Is, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I, Is...>, Reduce, Init>
template <typename Seq, typename Reduce, index_t Init, bool Reverse>
struct sequence_inclusive_scan_impl;
template <index_t... Is, typename Reduce, index_t Init, bool Reverse>
struct sequence_inclusive_scan_impl<sequence<Is...>, Reduce, Init, Reverse>
{
using old_scan = typename sequence_reverse_inclusive_scan<sequence<Is...>, Reduce, Init>::type;
// Produces the scanned sequence type for the pack `Is...`.
//
// `Indices` enumerates the output positions (the `type` alias of this struct
// passes make_index_sequence<sizeof...(Is)>), so entry k of the returned
// sequence is entry k of the scan table computed below.
//
// Each output is Reduce{}(input element, previous scan value), with the first
// processed element combined with Init. `Reverse` selects whether elements are
// processed from the back (true) or from the front (false).
template <index_t... Indices>
static constexpr auto compute(sequence<Indices...>)
{
    constexpr index_t count = sizeof...(Is);
    if constexpr(count != 0)
    {
        // Evaluate the whole scan once, in a constant expression, into a plain array.
        constexpr auto scanned = []() {
            static_array<index_t, count> input  = {Is...};
            static_array<index_t, count> output = {0};
            // `carry` is the most recently produced scan value, seeded with Init.
            index_t carry = Init;
            if constexpr(Reverse)
            {
                // Walk from the last element down to the first; k is one past
                // the slot being written so the loop condition never needs a
                // negative index.
                for(index_t k = count; k > 0; --k)
                {
                    carry         = Reduce{}(input[k - 1], carry);
                    output[k - 1] = carry;
                }
            }
            else
            {
                // Walk from the first element up to the last.
                for(index_t k = 0; k < count; ++k)
                {
                    carry     = Reduce{}(input[k], carry);
                    output[k] = carry;
                }
            }
            return output;
        }();
        return sequence<scanned[Indices]...>{};
    }
    else
    {
        // Empty input: there is nothing to scan.
        return sequence<>{};
    }
}
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front());
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
};
} // namespace impl
using type = typename sequence_merge<sequence<new_reduce>, old_scan>::type;
template <typename Seq, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan
{
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, true>::type;
};
template <index_t I, typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<I>, Reduce, Init>
template <typename Seq, typename Reduce, index_t Init>
struct sequence_inclusive_scan
{
using type = sequence<Reduce{}(I, Init)>;
};
template <typename Reduce, index_t Init>
struct sequence_reverse_inclusive_scan<sequence<>, Reduce, Init>
{
using type = sequence<>;
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, false>::type;
};
// split sequence
@@ -880,7 +918,7 @@ CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce,
template <typename Seq, typename Reduce, index_t Init>
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<Init>)
{
return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number<Init>{}).reverse();
return typename sequence_inclusive_scan<Seq, Reduce, Init>::type{};
}
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add

View File

@@ -0,0 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/numeric/integer.hpp"
namespace ck_tile {
// static_array<T, N>: a bare fixed-size array wrapper.
//
// The struct declares no constructors and exposes its storage as a public
// member, so it is an aggregate and is initialized with braces:
//
//     static_array<int, 3> arr{1, 2, 3};
//
// All accessors are constexpr, which lets the type be filled in and read back
// inside constant expressions (e.g. template metaprogramming helpers).
//
// NOTE(review): the original author's note contrasts this with
// ck_tile::array (which is said to have custom constructors) and targets GPU
// kernel code; whether an instance is trivially copyable depends on T.
template <typename T, index_t N>
struct static_array
{
    // Public storage; keeping it public (and constructor-free) is what makes
    // brace/aggregate initialization possible.
    T elems[N];

    // Element access by position. No bounds checking is performed.
    constexpr T& operator[](index_t idx) { return elems[idx]; }
    constexpr const T& operator[](index_t idx) const { return elems[idx]; }

    // Element count, i.e. the template argument N.
    static constexpr index_t size() { return N; }
};
} // namespace ck_tile