From 0d8b3d36b2595af46b115ecb05f761e4b954ae24 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 14 Oct 2024 11:56:45 -0500 Subject: [PATCH] Add custom type vector support (#1333) * Add non_native_vector_type * Add a test * Add non-native vector type * Fix CTOR * Fix non-native vector type of 1 * Fix CTORs * Use vector_type to cover non-native implementation as well * Update the test * Format * Format * Fix copyright years * Remove BoolVecT so far * Add AsType test cases * Update assert error message * Remove redundant type * Update naming * Add complex half type with tests * Add tests for vector reshaping * Add missing alignas * Update test/data_type/test_custom_type.cpp Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> * Compare custom types to built-in types * Add default constructor test * Add an alignment test --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Po Yen Chen [ROCm/composable_kernel commit: 4cf70b36c1330b3ee25e00473b219857575d3df2] --- include/ck/utility/data_type.hpp | 655 ++++++++++++++++++++- test/data_type/CMakeLists.txt | 5 + test/data_type/test_custom_type.cpp | 874 ++++++++++++++++++++++++++++ 3 files changed, 1504 insertions(+), 30 deletions(-) create mode 100644 test/data_type/test_custom_type.cpp diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 4df14c6211..debeb472ad 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -13,8 +13,24 @@ using int4_t = _BitInt(4); using f8_t = _BitInt(8); using bf8_t = unsigned _BitInt(8); +inline constexpr auto next_pow2(uint32_t x) +{ + // Precondition: x > 1. + return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; +} + +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool +template +inline constexpr bool is_native_type() +{ + return is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value; +} + // vector_type -template +template struct vector_type; // Caution: DO NOT REMOVE @@ -171,7 +187,7 @@ struct scalar_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; using type = d1_t; @@ -189,7 +205,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } @@ -197,7 +214,8 @@ struct vector_type template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value, "wrong!"); + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); return data_.d1x1_; } @@ -205,7 +223,7 @@ struct vector_type __device__ int static err = 0; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -226,7 +244,8 @@ struct vector_type template __host__ __device__ constexpr const auto& AsType() const { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -245,7 +264,8 @@ struct vector_type template __host__ __device__ constexpr auto& AsType() { - static_assert(is_same::value || is_same::value, "wrong!"); + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -263,7 +283,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -287,7 +307,7 @@ struct vector_type __host__ __device__ constexpr const auto& AsType() const { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -311,7 +331,7 @@ struct vector_type __host__ __device__ constexpr auto& AsType() { static_assert(is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -333,7 +353,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -360,7 +380,7 @@ struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -389,7 +409,7 @@ struct vector_type { static_assert(is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -415,7 +435,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -445,7 +465,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -479,7 +499,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -509,7 +529,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -541,7 +561,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -579,7 +599,7 @@ struct vector_type static_assert(is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -613,7 +633,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -648,7 +668,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -691,7 +711,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -729,7 +749,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -766,7 +786,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -813,7 +833,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -855,7 +875,7 @@ struct vector_type }; template -struct vector_type +struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); @@ -894,7 +914,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -945,7 +965,7 @@ struct vector_type is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value, - "wrong!"); + "Something went wrong, please check src and dst types."); if constexpr(is_same::value) { @@ -990,6 +1010,581 @@ struct vector_type } }; +template +struct non_native_vector_base +{ + using type = non_native_vector_base; + + __host__ __device__ non_native_vector_base() = default; + __host__ __device__ non_native_vector_base(const type&) = default; + __host__ __device__ non_native_vector_base(type&&) = default; + __host__ __device__ ~non_native_vector_base() = default; + + T d[N]; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using type = d1_t; + + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + + using type = d2_t; + + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + + using type = d4_t; + + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + using int64_t = long; // fp64 @@ -1051,8 +1646,8 @@ using bf8x8_t = typename vector_type::type; using bf8x16_t = typename vector_type::type; using bf8x32_t = typename vector_type::type; using bf8x64_t = typename vector_type::type; + // u8 -// i8 using uint8x2_t = typename vector_type::type; using uint8x4_t = typename vector_type::type; using uint8x8_t = typename vector_type::type; diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index 95f1367fbf..a783be7bb0 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -18,4 +18,9 @@ if(result EQUAL 0) target_link_libraries(test_bf8 PRIVATE utility) endif() +add_gtest_executable(test_custom_type test_custom_type.cpp) +if(result EQUAL 0) + target_link_libraries(test_custom_type PRIVATE utility) +endif() + add_gtest_executable(test_type_convert_const type_convert_const.cpp) diff --git a/test/data_type/test_custom_type.cpp b/test/data_type/test_custom_type.cpp new file mode 100644 index 0000000000..1016812544 --- /dev/null +++ b/test/data_type/test_custom_type.cpp @@ -0,0 +1,874 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bf8_t; +using ck::bhalf_t; +using ck::f8_t; +using ck::half_t; +using ck::Number; +using ck::type_convert; +using ck::vector_type; + +TEST(Custom_bool, TestSize) +{ + struct custom_bool_t + { + bool data; + }; + ASSERT_EQ(sizeof(custom_bool_t), sizeof(bool)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bool, TestAsType) +{ + struct custom_bool_t + { + using type = bool; + type data; + custom_bool_t() : data{type{}} {} + custom_bool_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {false, true, false, true}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, false); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bool_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bool, TestAsTypeReshape) +{ + struct custom_bool_t + { + using type = bool; + type data; + custom_bool_t() : data{type{}} {} + custom_bool_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {false, true, false, true}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, false); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bool_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_int8, TestSize) +{ + struct custom_int8_t + { + int8_t data; + }; + ASSERT_EQ(sizeof(custom_int8_t), sizeof(int8_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_int8, TestAsType) +{ + struct custom_int8_t + { + using type = int8_t; + type data; + custom_int8_t() : data{type{}} {} + custom_int8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, -6, 8, -2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_int8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_int8, TestAsTypeReshape) +{ + struct custom_int8_t + { + using type = int8_t; + type data; + custom_int8_t() : data{type{}} {} + custom_int8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, -6, 8, -2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_int8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_uint8, TestSize) +{ + struct custom_uint8_t + { + uint8_t data; + }; + ASSERT_EQ(sizeof(custom_uint8_t), sizeof(uint8_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_uint8, TestAsType) +{ + struct custom_uint8_t + { + using type = uint8_t; + type data; + custom_uint8_t() : data{type{}} {} + custom_uint8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, 6, 8, 2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_uint8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_uint8, TestAsTypeReshape) +{ + struct custom_uint8_t + { + using type = uint8_t; + type data; + custom_uint8_t() : data{type{}} {} + custom_uint8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {3, 6, 8, 2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_uint8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_f8, TestSize) +{ + struct custom_f8_t + { + _BitInt(8) data; + }; + ASSERT_EQ(sizeof(custom_f8_t), sizeof(_BitInt(8))); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 2>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 4>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 8>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 16>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 32>)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type<_BitInt(8), 64>)); +} + +TEST(Custom_f8, TestAsType) +{ + struct custom_f8_t + { + using type = _BitInt(8); + type data; + custom_f8_t() : data{type{}} {} + custom_f8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f), + type_convert<_BitInt(8)>(-0.6f), + type_convert<_BitInt(8)>(0.8f), + type_convert<_BitInt(8)>(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_f8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_f8, TestAsTypeReshape) +{ + struct custom_f8_t + { + using type = _BitInt(8); + type data; + custom_f8_t() : data{type{}} {} + custom_f8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f), + type_convert<_BitInt(8)>(-0.6f), + type_convert<_BitInt(8)>(0.8f), + type_convert<_BitInt(8)>(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_f8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bf8, TestSize) +{ + struct custom_bf8_t + { + unsigned _BitInt(8) data; + }; + ASSERT_EQ(sizeof(custom_bf8_t), sizeof(unsigned _BitInt(8))); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bf8, TestAsType) +{ + struct custom_bf8_t + { + using type = unsigned _BitInt(8); + type data; + custom_bf8_t() : data{type{}} {} + custom_bf8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bf8_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bf8, TestAsTypeReshape) +{ + struct custom_bf8_t + { + using type = unsigned _BitInt(8); + type data; + custom_bf8_t() : data{type{}} {} + custom_bf8_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}( + [&](auto i) { ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0); }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bf8_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_half, TestSize) +{ + struct custom_half_t + { + half_t data; + }; + ASSERT_EQ(sizeof(custom_half_t), sizeof(half_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_half, TestAsType) +{ + struct custom_half_t + { + using type = half_t; + type data; + custom_half_t() : data{type{}} {} + custom_half_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_half_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_half, TestAsTypeReshape) +{ + struct custom_half_t + { + using type = half_t; + type data; + custom_half_t() : data{type{}} {} + custom_half_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_half_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bhalf, TestSize) +{ + struct custom_bhalf_t + { + bhalf_t data; + }; + ASSERT_EQ(sizeof(custom_bhalf_t), sizeof(bhalf_t)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_bhalf, TestAsType) +{ + struct custom_bhalf_t + { + using type = bhalf_t; + type data; + custom_bhalf_t() : data{type{}} {} + custom_bhalf_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bhalf_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_bhalf, TestAsTypeReshape) +{ + struct custom_bhalf_t + { + using type = bhalf_t; + type data; + custom_bhalf_t() : data{type{}} {} + custom_bhalf_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {type_convert(0.3f), + type_convert(-0.6f), + type_convert(0.8f), + type_convert(-0.2f)}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_bhalf_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_float, TestSize) +{ + struct custom_float_t + { + float data; + }; + ASSERT_EQ(sizeof(custom_float_t), sizeof(float)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_float, TestAsType) +{ + struct custom_float_t + { + using type = float; + type data; + custom_float_t() : data{type{}} {} + custom_float_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3f, -0.6f, 0.8f, -0.2f}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0f); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_float_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_float, TestAsTypeReshape) +{ + struct custom_float_t + { + using type = float; + type data; + custom_float_t() : data{type{}} {} + custom_float_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3f, -0.6f, 0.8f, -0.2f}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0f); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_float_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_double, TestSize) +{ + struct custom_double_t + { + double data; + }; + ASSERT_EQ(sizeof(custom_double_t), sizeof(double)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), sizeof(vector_type)); +} + +TEST(Custom_double, TestAsType) +{ + struct custom_double_t + { + using type = double; + type data; + custom_double_t() : data{type{}} {} + custom_double_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3, 0.6, 0.8, 0.2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_double_t{test_vec.at(i)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Custom_double, TestAsTypeReshape) +{ + struct custom_double_t + { + using type = double; + type data; + custom_double_t() : data{type{}} {} + custom_double_t(type init) : data{init} {} + }; + + // test size + const int size = 4; + std::vector test_vec = {0.3, 0.6, 0.8, 0.2}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).data, 0.0); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = custom_double_t{test_vec.at(i)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).data, test_vec.at(i)); + }); +} + +TEST(Complex_half, TestSize) +{ + struct complex_half_t + { + half_t real; + half_t img; + }; + ASSERT_EQ(sizeof(complex_half_t), sizeof(half_t) + sizeof(half_t)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); + ASSERT_EQ(sizeof(vector_type), + sizeof(vector_type) + sizeof(vector_type)); +} + +TEST(Complex_half, TestAlignment) +{ + struct complex_half_t + { + half_t real; + half_t img; + }; + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); + ASSERT_EQ(alignof(vector_type), + alignof(vector_type) + alignof(vector_type)); +} + +TEST(Complex_half, TestAsType) +{ + struct complex_half_t + { + using type = half_t; + type real; + type img; + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + + // test size + const int size = 4; + // custom type number of elements + const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type); + std::vector test_vec = {half_t{0.3f}, + half_t{-0.6f}, + half_t{0.8f}, + half_t{-0.2f}, + half_t{0.5f}, + half_t{-0.7f}, + half_t{0.9f}, + half_t{-0.3f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).real, + type_convert(0.0f)); + ASSERT_EQ(right_vec.template AsType()(Number{}).img, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)}; + }); + // copy the vector + vector_type left_vec{right_vec}; + // check if values were copied correctly + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).real, + test_vec.at(num_elem * i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).img, + test_vec.at(num_elem * i + 1)); + }); +} + +TEST(Complex_half, TestAsTypeReshape) +{ + struct complex_half_t + { + using type = half_t; + type real; + type img; + complex_half_t() : real{type{}}, img{type{}} {} + complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {} + }; + + // test size + const int size = 4; + // custom type number of elements + const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type); + std::vector test_vec = {half_t{0.3f}, + half_t{-0.6f}, + half_t{0.8f}, + half_t{-0.2f}, + half_t{0.5f}, + half_t{-0.7f}, + half_t{0.9f}, + half_t{-0.3f}}; + // reference vector + vector_type right_vec; + // check default CTOR + ck::static_for<0, size, 1>{}([&](auto i) { + ASSERT_EQ(right_vec.template AsType()(Number{}).real, + type_convert(0.0f)); + ASSERT_EQ(right_vec.template AsType()(Number{}).img, + type_convert(0.0f)); + }); + // assign test values to the vector + ck::static_for<0, size, 1>{}([&](auto i) { + right_vec.template AsType()(Number{}) = + complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)}; + }); + // copy the first half of a vector + vector_type left_vec{ + right_vec.template AsType::type>()(Number<0>{})}; + // check if values were copied correctly + ck::static_for<0, size / 2, 1>{}([&](auto i) { + ASSERT_EQ(left_vec.template AsType()(Number{}).real, + test_vec.at(num_elem * i)); + ASSERT_EQ(left_vec.template AsType()(Number{}).img, + test_vec.at(num_elem * i + 1)); + }); +}