mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[CK] refactoring according to the review feedback
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
#ifndef CK_ARRAY_HPP
|
||||
#define CK_ARRAY_HPP
|
||||
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include <type_traits>
|
||||
#include <cassert>
|
||||
#include "type.hpp"
|
||||
@@ -44,8 +42,7 @@ struct Array
|
||||
{
|
||||
static_assert(T::Size() == Size(), "wrong! size not the same");
|
||||
|
||||
#pragma unroll
|
||||
for(int i = 0; i < NSize; i++)
|
||||
for(index_t i = 0; i < NSize; i++)
|
||||
{
|
||||
mData[i] = a[i];
|
||||
}
|
||||
|
||||
@@ -375,6 +375,19 @@ struct sequence_reverse_inclusive_scan_impl;
|
||||
template <index_t... Is, typename Reduce, index_t Init>
|
||||
struct sequence_reverse_inclusive_scan_impl<Sequence<Is...>, Reduce, Init>
|
||||
{
|
||||
template <index_t Size>
|
||||
static constexpr Array<index_t, Size> compute_array()
|
||||
{
|
||||
Array<index_t, Size> values = {Is...};
|
||||
Array<index_t, Size> result = {0};
|
||||
result.At(Size - 1) = Reduce{}(values[Size - 1], Init);
|
||||
for(index_t i = Size - 1; i > 0; --i)
|
||||
{
|
||||
result.At(i - 1) = Reduce{}(values[i - 1], result[i]);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <index_t... Indices>
|
||||
static constexpr auto compute(Sequence<Indices...>)
|
||||
{
|
||||
@@ -384,18 +397,21 @@ struct sequence_reverse_inclusive_scan_impl<Sequence<Is...>, Reduce, Init>
|
||||
{
|
||||
return Sequence<>{};
|
||||
}
|
||||
else if constexpr(size == 1)
|
||||
{
|
||||
constexpr index_t values[1] = {Is...};
|
||||
return Sequence<Reduce{}(values[0], Init)>{};
|
||||
}
|
||||
else if constexpr(size == 2)
|
||||
{
|
||||
constexpr index_t values[2] = {Is...};
|
||||
constexpr index_t r1 = Reduce{}(values[1], Init);
|
||||
constexpr index_t r0 = Reduce{}(values[0], r1);
|
||||
return Sequence<r0, r1>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr Array<index_t, size> arr = []() {
|
||||
Array<index_t, size> values = {Is...};
|
||||
Array<index_t, size> result = {0};
|
||||
result.At(size - 1) = Reduce{}(values[size - 1], Init);
|
||||
for(index_t i = size - 1; i > 0; --i)
|
||||
{
|
||||
result.At(i - 1) = Reduce{}(values[i - 1], result[i]);
|
||||
}
|
||||
return result;
|
||||
}();
|
||||
constexpr Array<index_t, size> arr = compute_array<size>();
|
||||
return Sequence<arr[Indices]...>{};
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
using namespace ck;
|
||||
|
||||
// Test basic Sequence construction and properties
|
||||
// Test basic Array construction and properties
|
||||
TEST(Array, BasicConstruction)
|
||||
{
|
||||
using Arr = Array<index_t, 5>;
|
||||
|
||||
Reference in New Issue
Block a user