mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4355 (commit e7f6909)
[CK TILE] Refactor sequence_reverse_inclusive_scan ## Proposed changes Refactor ck tile `sequence_reverse_inclusive_scan` from recursive to for-loop. Tracking issue: #4229 This pull request introduces a new lightweight array type, `static_array`, and refactors the sequence utilities to use it for improved constexpr support and simplicity. The changes also include updates to the build system to add container-related tests. **Core Library Improvements:** * Added a new header `static_array.hpp` that defines the `static_array` type, a constexpr-friendly array with basic accessors and no custom constructors. * Updated includes in `core.hpp` and `sequence.hpp` to import `static_array`. [[1]](diffhunk://#diff-14b406eccf59794051a16c0c9c1a7e11234324bfdd107a5bbe0f173cd25bcddcR44) [[2]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76R7) **Refactoring to Use `static_array`:** * Refactored sequence utilities in `sequence.hpp` to use `static_array` instead of the previously forward-declared `array` type, including in histogram and array generation logic. [[1]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76L1108-R1133) [[2]](diffhunk://#diff-5042e5b47bb2ba78bbab2d284338cf0503bc8fb76a7d631cc2684ad6ca832a76L1130-R1146) * Rewrote the implementation of `sequence_reverse_inclusive_scan` to use `static_array` for intermediate storage, improving constexpr evaluation and clarity. **Build System and Testing:** * Added a new test subdirectory for container tests and a GoogleTest executable for `unit_sequence.cpp` to the CMake build configuration. [[1]](diffhunk://#diff-5d35ff7555d3f0b438d45cde06b661eb1332cdbec66287ac7ec3c478d688aae5R5) [[2]](diffhunk://#diff-1f54f0d2b431b7fc74f7b4ffb66e80c381c904c3383b1d27987467e3482d6d7aR1-R7) Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fc3180120e
commit
3af1a0aafc
@@ -39,6 +39,7 @@
|
||||
#include "ck_tile/core/container/multi_index.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
#include "ck_tile/core/container/static_array.hpp"
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/container/static_array.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
@@ -35,6 +36,7 @@ template <typename Seq>
|
||||
CK_TILE_HOST_DEVICE constexpr auto sequence_pop_back(Seq);
|
||||
|
||||
namespace impl {
|
||||
|
||||
// static_assert(__has_builtin(__type_pack_element), "can't find __type_pack_element");
|
||||
template <index_t I, typename... Ts>
|
||||
using at_index_t = __type_pack_element<I, Ts...>;
|
||||
@@ -331,30 +333,66 @@ struct uniform_sequence_gen
|
||||
using type = typename sequence_gen<NSize, F>::type;
|
||||
};
|
||||
|
||||
// reverse inclusive scan (with init) sequence
|
||||
template <typename, typename, index_t>
|
||||
struct sequence_reverse_inclusive_scan;
|
||||
// inclusive scan (with init) sequence
|
||||
namespace impl {
|
||||
|
||||
template <index_t I, index_t... Is, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<sequence<I, Is...>, Reduce, Init>
|
||||
template <typename Seq, typename Reduce, index_t Init, bool Reverse>
|
||||
struct sequence_inclusive_scan_impl;
|
||||
|
||||
template <index_t... Is, typename Reduce, index_t Init, bool Reverse>
|
||||
struct sequence_inclusive_scan_impl<sequence<Is...>, Reduce, Init, Reverse>
|
||||
{
|
||||
using old_scan = typename sequence_reverse_inclusive_scan<sequence<Is...>, Reduce, Init>::type;
|
||||
template <index_t... Indices>
|
||||
static constexpr auto compute(sequence<Indices...>)
|
||||
{
|
||||
constexpr index_t size = sizeof...(Is);
|
||||
if constexpr(size == 0)
|
||||
{
|
||||
return sequence<>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto arr = []() {
|
||||
static_array<index_t, size> values = {Is...};
|
||||
static_array<index_t, size> result = {0};
|
||||
if constexpr(Reverse)
|
||||
{
|
||||
// Reverse scan: right to left
|
||||
result[size - 1] = Reduce{}(values[size - 1], Init);
|
||||
for(index_t i = size - 1; i > 0; --i)
|
||||
{
|
||||
result[i - 1] = Reduce{}(values[i - 1], result[i]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Forward scan: left to right
|
||||
result[0] = Reduce{}(values[0], Init);
|
||||
for(index_t i = 1; i < size; ++i)
|
||||
{
|
||||
result[i] = Reduce{}(values[i], result[i - 1]);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}();
|
||||
return sequence<arr[Indices]...>{};
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.front());
|
||||
using type = decltype(compute(make_index_sequence<sizeof...(Is)>{}));
|
||||
};
|
||||
} // namespace impl
|
||||
|
||||
using type = typename sequence_merge<sequence<new_reduce>, old_scan>::type;
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan
|
||||
{
|
||||
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, true>::type;
|
||||
};
|
||||
|
||||
template <index_t I, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<sequence<I>, Reduce, Init>
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
struct sequence_inclusive_scan
|
||||
{
|
||||
using type = sequence<Reduce{}(I, Init)>;
|
||||
};
|
||||
|
||||
template <typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan<sequence<>, Reduce, Init>
|
||||
{
|
||||
using type = sequence<>;
|
||||
using type = typename impl::sequence_inclusive_scan_impl<Seq, Reduce, Init, false>::type;
|
||||
};
|
||||
|
||||
// split sequence
|
||||
@@ -880,7 +918,7 @@ CK_TILE_HOST_DEVICE constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce,
|
||||
template <typename Seq, typename Reduce, index_t Init>
|
||||
CK_TILE_HOST_DEVICE constexpr auto inclusive_scan_sequence(Seq, Reduce, number<Init>)
|
||||
{
|
||||
return reverse_inclusive_scan_sequence(Seq{}.reverse(), Reduce{}, number<Init>{}).reverse();
|
||||
return typename sequence_inclusive_scan<Seq, Reduce, Init>::type{};
|
||||
}
|
||||
|
||||
// e.g. Seq<2, 3, 4> --> Seq<0, 2, 5>, Init=0, Reduce=Add
|
||||
|
||||
30
include/ck_tile/core/container/static_array.hpp
Normal file
30
include/ck_tile/core/container/static_array.hpp
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Fixed-size array with aggregate initialization
|
||||
//
|
||||
// This is a minimal array type designed for:
|
||||
// - Constexpr/compile-time computation
|
||||
// - GPU kernel code (trivially copyable)
|
||||
// - Template metaprogramming
|
||||
//
|
||||
// Unlike ck_tile::array, this has no custom constructors,
|
||||
// making it a literal type suitable for constexpr contexts.
|
||||
// Use aggregate initialization: static_array<int, 3> arr{1, 2, 3};
|
||||
template <typename T, index_t N>
|
||||
struct static_array
|
||||
{
|
||||
// Public aggregate initialization makes this a literal type
|
||||
T elems[N];
|
||||
|
||||
// Basic constexpr accessors
|
||||
constexpr const T& operator[](index_t i) const { return elems[i]; }
|
||||
constexpr T& operator[](index_t i) { return elems[i]; }
|
||||
|
||||
constexpr static index_t size() { return N; }
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user