Add tensor partition and generic copy for ck wrapper (#1108)

* Add tensor partition and generic copy for ck wrapper

* Update changelog

* Stylistic fixes

* Change shape/strides logic to descriptor transforms

* Fixes

* Fix client example

* Fix comments
This commit is contained in:
Bartłomiej Kocot
2024-01-03 01:10:57 +01:00
committed by GitHub
parent b268f273de
commit 4234b3a691
14 changed files with 940 additions and 306 deletions

View File

@@ -84,7 +84,8 @@ TEST_F(TestWrapperLayout, 2d)
ck::make_tuple(ck::Sequence<0>{}));
const auto layout_runtime = ck::wrapper::make_layout(ck::make_tuple(d1, d0));
const auto layout_compiletime =
ck::wrapper::make_layout(ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}));
ck::wrapper::make_layout(ck::make_tuple(ck::Number<d1>{}, ck::Number<d0>{}),
ck::make_tuple(ck::Number<s1>{}, ck::Number<s0>{}));
std::vector<ck::Tuple<ck::index_t, ck::index_t>> idxs;
for(ck::index_t h = 0; h < d1; h++)
@@ -435,19 +436,11 @@ TEST(TestLayoutHelpers, ShapeAndStrides)
constexpr bool check_compiletime_shape =
std::is_same_v<decltype(shape_compiletime),
std::remove_reference_t<decltype(shape(layout_compiletime))>>;
constexpr bool check_compiletime_strides =
std::is_same_v<decltype(strides_compiletime),
std::remove_reference_t<decltype(stride(layout_compiletime))>>;
constexpr bool check_runtime_shape =
std::is_same_v<decltype(shape_runtime),
std::remove_reference_t<decltype(shape(layout_runtime))>>;
constexpr bool check_runtime_strides =
std::is_same_v<decltype(strides_runtime),
std::remove_reference_t<decltype(stride(layout_runtime))>>;
EXPECT_TRUE(check_compiletime_shape);
EXPECT_TRUE(check_compiletime_strides);
EXPECT_TRUE(check_runtime_shape);
EXPECT_TRUE(check_runtime_strides);
}
TEST(TestLayoutHelpers, Hierarchical)