From 6ed0dde669fa58b5a304cd56640fec19d010074f Mon Sep 17 00:00:00 2001 From: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Date: Wed, 11 Feb 2026 11:59:43 -0700 Subject: [PATCH] [CK] Optimize vector type build times (#4471) **Supercedes https://github.com/ROCm/rocm-libraries/pull/4281 due to CI issues on import** ## Proposed changes Build times can be affected by many different things and is highly attributed to the way we write and use the code. Two critical areas of the builds are **frontend parsing** and **backend codegen and compilation**. ### Frontend Parsing The length of the code, the include header tree and macro expansions all affect the front-end parsing time. This PR seeks to reduce the parsing time of the dtype_vector.hpp vector_type class by reducing redundant code by generalization. * Partial specializations of vector_type for native and non-native datatypes have been generalized to one single class, consolidating all of the data initialization and AsType casting requirements into one place. * The class nnvb_data_t_selector (e.g., Non-native vector base dataT selector) class has been removed and replaced with scalar_type instantiations as they have the same purpose. Scalar type class' purpose is already to map generalized datatypes to native types compatible with ext_vector_t. ### Backend Codegen Template instantiation behavior can also affect build times. Recursive instantiations are very slow versus concrete instantiations. The compiler must make multiple passes to expand template instantiations so we need to be careful about how they are used. * Previous vector_type classes declared a union storage class, which aliases StaticallyIndexedArray. ``` template struct vector_type()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); using type = d4_t; union { d4_t d4_; StaticallyIndexedArray d1x4_; StaticallyIndexedArray d2x2_; StaticallyIndexedArray d4x1_; } data_; ... }; ``` * Upon further inspection, StaticallyIndexedArray is built on-top of a recursive Tuple concatenation. ``` template struct StaticallyIndexedArrayImpl { using type = typename tuple_concat::type, typename StaticallyIndexedArrayImpl::type>::type; }; ``` This union storage has been removed from the vector_type storage class. * Further references to StaticallyIndexedArray have been replaced with StaticallyIndexedArray_v2, which is a concrete implementation using C-style arrays. ``` template struct StaticallyIndexedArray_v2 { ... T data_[N]; }; ``` ### Fixes * Using bool datatype with vector_type was previously error prone. Bool, as a native datatype would be stored into bool ext_vector_type(N) for storage, which is a packed datatype. Meaning that for example, sizeof(bool ext_vector_type(4)) == 1, which does not equal sizeof(StaticallyIndexedArray == 4. The union of these datatypes has incorrect data slicing, meaning that the bits location of the packed bool do not match with the StaticallyIndexedArray member. As such, vector_type will use C-Style array storage for bool type instead of ext_vector_type. ``` template using NativeVectorT = T __attribute__((ext_vector_type(Rank))); sizeof(NativeVectorT) == 1 (1 byte per 4 bool - packed) element0 = bit 0 of byte 0 element1 = bit 1 of byte 0 element2 = bit 2 of byte 0 element3 = bit 3 of byte 0 sizeof(StaticallyIndexedArray[NativeVectorT, 4] == 4 (1 byte per bool) element0 = bit 0 of byte 0 element1 = bit 0 of byte 1 element1 = bit 0 of byte 2 element1 = bit 0 of byte 3 union{ NativeVectorT d1_t; ... StaticallyIndexedArray[NativeVectorT, 4] d4x1; }; // union size == 4 which means invalid slicing! ``` * Math utilities such as next_power_of_two addressed for invalid cases of X < 2 * Remove redundant implementation of next_pow2 ### Additions * integer_log2_floor to math.hpp * is_power_of_two_integer to math.hpp ### Build Time Analysis Machine: banff-cyxtera-s78-2 Target: gfx942 | Build Target | Threads | Frontend Parse Time (s) | Backend Codegen Time (s) | TotalTime (s) | commitId | |---------------|---------|-------------------------|--------------------------|---------------| ---------------| | device_grouped_conv3d_fwd_bias_bnorm_clamp_instance | 1 | 1452 | 331 | 1783 | 2e08a7e (develop) | | device_grouped_conv3d_fwd_bias_bnorm_clamp_instance | 1 | 1403 | 332 | 1735 (-2.7%) | fad4235| ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: systems-assistant[bot] Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> --- include/ck/utility/data_type.hpp | 165 +- include/ck/utility/dtype_vector.hpp | 2282 ++++--------------------- include/ck/utility/dynamic_buffer.hpp | 5 +- include/ck/utility/math.hpp | 22 +- include/ck/utility/type_convert.hpp | 4 +- 5 files changed, 512 insertions(+), 1966 deletions(-) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 8e6f875c39..ff0bb10d0c 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -34,9 +34,48 @@ using f4_t = unsigned _BitInt(4); using f6_t = _BitInt(6); // e2m3 format using bf6_t = unsigned _BitInt(6); // e3m2 format -// scalar_type -template -struct scalar_type; +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool +template +inline constexpr bool is_native_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Wrapper for native vector type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using NativeVectorT = T __attribute__((ext_vector_type(Rank))); + +/** + * @brief Mapping of incoming type to local native vector storage type and vector size + * @tparam T Incoming data type + */ +template +struct scalar_type +{ + // Basic data type mapping to unsigned _BitInt of appropriate size + using type = unsigned _BitInt(8 * sizeof(T)); + static constexpr index_t vector_size = 1; +}; + +/** + * @brief scalar_type trait override for NativeVectorT + * @tparam T The vector type + * @tparam Rank The number of elements in the vector + */ +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = Rank; +}; struct f4x2_pk_t { @@ -74,6 +113,39 @@ struct f4x2_pk_t } }; +// TODO: Unfortunately, we cannot partially specialize scalar_type for vectors written +// in the following way: +// template +// struct scalar_type +// { +// using type = T; +// static constexpr index_t vector_size = Rank; +// }; +// The compiler errors out with "partial specialization is not allowed for this type", +// claiming that the Rank is not a deducible parameter. This might be a compiler bug. +// Note the above type is classified differently from the NativeVectorT alias, +// even though they are functionally equivalent and are trivially constructibe from each other. +// This is unfortunate, but we have to work around it because some LLVM builtins for some +// operations (e.g., mma) may return the former type. +// For now we have to explicitly specialize for each vector size we need. These are used +// in f6_pk_t below. + +/// @brief scalar_type trait override for uint32_t vector of size 3 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 3; +}; + +/// @brief scalar_type trait override for uint32_t vector of size 6 +template <> +struct scalar_type +{ + using type = uint32_t; + static constexpr index_t vector_size = 6; +}; + template struct f6_pk_t { @@ -89,28 +161,48 @@ struct f6_pk_t static constexpr index_t vector_size = (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units - using storage_type = element_type __attribute__((ext_vector_type(vector_size))); + using storage_type = NativeVectorT; storage_type data_{storage_type(0)}; // packed data using type = f6_pk_t; + /** This class may trivially constructed by the following vector type alias + * for example from a result of an mma operation. This is primarily for internal use. + * @note f6x16_pk_t and f6x32_pk_t storage types, may be trivially constructed from + * uint32_t vectors of size 3 and 6 respectively for example from mma operation results. + * Unfortunately, unsigned int __attribute__((ext_vector_type(6))) a.k.a + * NativeVectorT is NOT the same as __attribute__((__vector_size__(6 * + * sizeof(unsigned int)))) unsigned int which is returned from the mma ops despite being + * functionally equivalent. This class may be trivially constructed from both, so we can steer + * the templated ctor below to only consider incoming vectors types other than our two storage + * types of interest. + */ + using storage_type_alias = + element_type __attribute__((__vector_size__(sizeof(element_type) * vector_size))); + __host__ __device__ constexpr f6_pk_t() {} __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} { // TODO: consider removing initialization similar to vector_type } - // Initialize from a vector type with the same size as packed_size - template ::vector_size == packed_size>> + // Initialize from a vector type with the same size as packed_size. + // Exclude storage_type and storage_type_alias because these are trivially constructible. + template < + typename T, + typename = enable_if_t && !is_same_v && + scalar_type::vector_size == packed_size>> __host__ __device__ f6_pk_t(const T& v) { + static_assert(scalar_type::vector_size == packed_size, + "Input vector size must match packed_size."); static_for<0, packed_size, 1>{}( [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); } // Broadcast single initialization value to all packed elements __host__ __device__ f6_pk_t(const int8_t v) - : f6_pk_t(static_cast(v)) + : f6_pk_t(static_cast>(v)) { // TODO: consider removing initialization similar to vector_type } @@ -191,27 +283,6 @@ struct pk_i4_t __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} }; -inline constexpr auto next_pow2(uint32_t x) -{ - // Precondition: x > 1. - return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; -} - -// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, -// native types: bool -template -inline constexpr bool is_native_type() -{ - return is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same_v || is_same_v || is_same::value; -} - -// scalar_type -template -struct scalar_type; - // is_scalar_type template struct is_scalar_type @@ -224,14 +295,13 @@ template using has_same_scalar_type = is_same>::type, typename scalar_type>::type>; -template -struct scalar_type +template <> +struct scalar_type { - using type = T; - static constexpr index_t vector_size = N; + using type = bool; + static constexpr index_t vector_size = 1; }; -// template <> struct scalar_type { @@ -293,35 +363,35 @@ struct scalar_type template <> struct scalar_type { - using type = pk_i4_t; + using type = typename pk_i4_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_fnuz_t::data_type; + using type = typename f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t::data_type; + using type = typename bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f8_ocp_t::data_type; + using type = typename f8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_ocp_t::data_type; + using type = typename bf8_ocp_t::data_type; static constexpr index_t vector_size = 1; }; @@ -329,7 +399,7 @@ struct scalar_type template <> struct scalar_type { - using type = e8m0_bexp_t::type; + using type = typename e8m0_bexp_t::type; static constexpr index_t vector_size = 1; }; #endif @@ -337,42 +407,35 @@ struct scalar_type template <> struct scalar_type { - using type = f4x2_pk_t::type; + using type = typename f4x2_pk_t::type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x32_pk_t::storage_type; + using type = typename f6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x32_pk_t::storage_type; + using type = typename bf6x32_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = f6x16_pk_t::storage_type; + using type = typename f6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf6x16_pk_t::storage_type; - static constexpr index_t vector_size = 1; -}; - -template <> -struct scalar_type -{ - using type = bool; + using type = typename bf6x16_pk_t::storage_type; static constexpr index_t vector_size = 1; }; diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index 204b199629..b6a199bd4a 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -3,1373 +3,34 @@ #pragma once #include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wlifetime-safety-intra-tu-suggestions" namespace ck { -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; - -template -using vector_type_maker_t = typename vector_type_maker::type; - -template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - -template -struct vector_type()>> -{ - using d1_t = T; - using type = d1_t; - - union - { - T d1_; - StaticallyIndexedArray d1x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } -}; - __device__ int static err = 0; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - - using type = d3_t; - - union - { - d3_t d3_; - StaticallyIndexedArray d1x3_; - StaticallyIndexedArray d2x1_; - StaticallyIndexedArray d3x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - - using type = d4_t; - - union - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d5_t __attribute__((ext_vector_type(5))); - - using type = d5_t; - - union - { - d5_t d5_; - StaticallyIndexedArray d1x5_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d5x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - typedef T d6_t __attribute__((ext_vector_type(6))); - - using type = d6_t; - - union - { - d6_t d6_; - StaticallyIndexedArray d1x6_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d3x2_; - StaticallyIndexedArray d6x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x6_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d3x2_; - } - else if constexpr(is_same::value) - { - return data_.d6x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x6_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d3x2_; - } - else if constexpr(is_same::value) - { - return data_.d6x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d7_t __attribute__((ext_vector_type(7))); - - using type = d7_t; - - union - { - d7_t d7_; - StaticallyIndexedArray d1x7_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d7x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d13_t __attribute__((ext_vector_type(13))); - - using type = d13_t; - - union - { - d13_t d13_; - StaticallyIndexedArray d1x13_; - StaticallyIndexedArray d4x3_; - StaticallyIndexedArray d8x1_; - StaticallyIndexedArray d13x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - - using type = d32_t; - - union - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_ = {d32_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - // __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - // __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - - using type = d64_t; - - union - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - - using type = d128_t; - - union - { - d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; - } data_ = {d128_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - - using type = d256_t; - - union - { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; - } data_ = {d256_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } -}; template struct non_native_vector_base; -template -struct nnvb_data_t_selector -{ - using type = unsigned _BitInt(8 * sizeof(T)); -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_ocp_t::data_type; -}; - -#ifndef CK_CODE_GEN_RTC -template <> -struct nnvb_data_t_selector -{ - using type = f8_fnuz_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_fnuz_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = e8m0_bexp_t::type; -}; -#endif - -template <> -struct nnvb_data_t_selector -{ - using type = f6x16_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x32_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x16_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x32_pk_t::storage_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = pk_i4_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f4x2_pk_t::type; -}; - template struct non_native_vector_base< T, N, ck::enable_if_t> { - using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + using data_t = typename scalar_type::type; // select data_t based on the size of T static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - using data_v = data_t __attribute__((ext_vector_type(N))); + using data_v = NativeVectorT; using type = non_native_vector_base; - union alignas(next_pow2(N * sizeof(T))) + union alignas(math::next_power_of_two()) { data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} @@ -1405,7 +66,7 @@ struct non_native_vector_base< } template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] { static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); @@ -1460,20 +121,19 @@ struct non_native_vector_base< N, ck::enable_if_t> { - using data_t = - typename nnvb_data_t_selector::type; // select data_t based on declared base type + using data_t = typename scalar_type::type; // select data_t based on declared base type using element_t = typename T::element_type; // select element_t based on declared element type static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); - using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); - using type = non_native_vector_base; + using data_v = NativeVectorT; + using type = non_native_vector_base; - union alignas(next_pow2(N * sizeof(T))) + union alignas(math::next_power_of_two()) { data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; + StaticallyIndexedArray_v2 dxN; + StaticallyIndexedArray_v2 dTxN; + StaticallyIndexedArray_v2 dNx1; } data_; // Broadcast single value to vector @@ -1512,7 +172,31 @@ struct non_native_vector_base< } template - __host__ __device__ constexpr const auto& AsType() const + __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dNx1; + } + else if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { static_assert(is_same_v || is_same_v || is_same_v, "Something went wrong, please check src and dst types."); @@ -1556,594 +240,378 @@ struct scalar_type::size_factor; }; -// non-native vector_type implementation +/** + * @brief Helper struct to determine the storage type for vector_type + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @tparam Enable SFINAE helper + */ +template +struct vector_type_storage; + +/** + * @brief Vector storage type for native scalar types. + * @tparam T The element type of the vector + * @note For Rank = 1 and native types, the storage type is simply T itself (scalar) + */ template -struct vector_type()>> +struct vector_type_storage()>> { - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using type = d1_nnv_t; - - union alignas(next_pow2(1 * sizeof(T))) - { - d1_t d1_; - StaticallyIndexedArray d1x1_; - d1_nnv_t d1_nnv_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } + using type = T; }; -template -struct vector_type()>> +/** + * @brief Vector storage type for native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * + * Assigns a native vector type based on the element type and rank. + * For boolean types, uses a C-style array `T[Rank]`, otherwise uses + * the `NativeVectorT` template specialization. + * + * @note Special handling note: + * Sub-byte sizes such as bool have different sizes in ext_vector_type (via NativeVectorT) vs array + * types due to packing. Builtin vector types pack bool elements, while C++ arrays use 1 byte per + * bool as a standard (minimum write size = 1 byte). e.g., ext_vector_type(bool, 4) is packed as + * minimum 1 byte, while bool[4] is 4 bytes. vector_type::AsType, aliases with + * StaticallyIndexedArray_v2 which is C-style array under the hood, so we need to avoid using + * ext_vector_type with bool due to potential for data slicing errors. + */ +template +struct vector_type_storage() && (Rank > 1)>> { - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - - using type = d2_t; - - union alignas(next_pow2(2 * sizeof(T))) - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const [[clang::lifetimebound]] - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } + using type = std::conditional_t, T[Rank], NativeVectorT>; }; -template -struct vector_type()>> +/** + * @brief Vector storage type for non-native vector types. + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + * @note For non-native types, the storage type is non_native_vector_base + */ +template +struct vector_type_storage()>> { - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - - using type = d4_t; - - union alignas(next_pow2(4 * sizeof(T))) - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } + using type = non_native_vector_base; }; -template -struct vector_type()>> +/** + * @brief Convenience wrapper for vector_type_storage + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +using vector_type_storage_t = typename vector_type_storage::type; + +/** + * @brief Trait to check whether one vector storage class is the same as another (e.g., same scalar, + * or same vector class). + * @tparam Lhs The source storage type + * @tparam Rhs The comparator storage type + * + * Same storage classes are: + * - Same type + * - Same template vector types with matching base type (may have different ranks) + * - C-style arrays of same base type (may have different ranks) + */ +template +struct is_same_vector_storage_class : public false_type { - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; +}; - using type = d8_t; +/** + * @brief Template native vector types of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + */ +template +struct is_same_vector_storage_class, NativeVectorT> + : true_type +{ +}; - union alignas(next_pow2(8 * sizeof(T))) - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; +/** + * @brief Template non-native vector types of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + * @tparam Enable SFINAE helper + */ +template +struct is_same_vector_storage_class, + non_native_vector_base> : true_type +{ +}; - __host__ __device__ constexpr vector_type() : data_{type{}} {} +/** + * @brief C-style arrays of same base type with different ranks + * @tparam T The base element type + * @tparam LhsRank The rank of the source type + * @tparam RhsRank The rank of the comparator type + */ +template +struct is_same_vector_storage_class : true_type +{ +}; - __host__ __device__ constexpr vector_type(type v) : data_{v} {} +/** + * @brief Convenience evaluator for is_same_vector_storage_class + * @tparam Lhs The source storage type + * @tparam Rhs The comparator storage type + */ +template +static constexpr bool is_same_vector_storage_class_v = + is_same_vector_storage_class::value; +// Fwd declaration +template +struct vector_type; + +/** + * @brief Trait to extract element type and rank from vector_type and related types + * @tparam T The vector type + */ +template +struct vector_type_traits +{ + using element_type = T; + static constexpr index_t Rank = 1; +}; + +/** + * @brief Specialization of vector_type_traits for vector_type + * @tparam T The element type of the vector + * @tparam Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Specialization of vector_type_traits for non_native_vector_base + * @tparam T The element type of the vector + * @tparam Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Specialization of vector_type_traits for NativeVectorT + * @tparam T The element type of the vector + * @tparam Rank_ The number of elements in the vector + */ +template +struct vector_type_traits> +{ + using element_type = T; + static constexpr index_t Rank = Rank_; +}; + +/** + * @brief Vector type wrapper + * @tparam T The element type of the vector + * @tparam Rank The number of elements in the vector + */ +template +struct vector_type +{ + /// @brief Internal storage type for vector_type. + using StorageT = vector_type_storage_t; + using type = StorageT; + StorageT data_; + + /// @brief Default constructor for vector_type + __host__ __device__ constexpr vector_type() : data_{} {} + + /// @brief Constructor for native vector initialization + __host__ __device__ constexpr vector_type(StorageT v) : data_{v} {} + + /** + * @brief Validates whether a type can be used in an AsType cast operation for vector_type + * class. + * + * This function checks if a given type X can be legally used as an alias to either reinterpret + * or slice (iterate) through the local storage type StorageT. The validation ensures type + * safety and structural compatibility between the source and target vector types. + * + * @tparam X The target type to validate for AsType casting. + * + * @return constexpr bool True if the type is valid for AsType casting, false otherwise + * + * @note Requirements for a valid AsType cast on vector_type: + * 1. The value type of X must match the storage value type (T) + * 2. X must be either: + * a) A scalar type (T) where RankX == 1, OR + * b) A vector class that matches the storage vector class (e.g., both are + * NativeVectorT or non_native_vector_base) where: + * - RankX is a power of 2, OR + * RankX == 3, OR + * RankX == Storage Rank + * - RankX must be <= Storage Rank + * @example + * auto srcVec = vector_type{}; // T = float, Rank = 8, native vector storage + * auto result = srcVec.AsType(); // Where datatype X could be: + * X = NativeVectorT; // OK: native vector T, RankX = 4 (power of 2) + * X = float; // OK: scalar T, RankX = 1 + * X = NativeVectorT; // ERROR: RankX not a power of 2, ==3, or ==Rank + * X = int; // ERROR: Invalid scalar cast, T != int + * X = float[4]; // ERROR: Invalid type, storage vector class doesn't + * // match (native vector != C-array) + */ template - __host__ __device__ constexpr const auto& AsType() const + static constexpr bool is_as_type_cast_valid() { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); + using TraitsX = vector_type_traits; - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } + // Checks storage classes match, with same base type (may have different ranks) + constexpr bool is_valid_cast = + is_same_vector_storage_class_v || // Matching vector storage + is_same_v; // Matching scalar type + + // Validate vector ranks + constexpr bool is_valid_rank = (math::is_power_of_two_integer(TraitsX::Rank) || + (TraitsX::Rank == 3) || (TraitsX::Rank == Rank)) && + (TraitsX::Rank <= Rank); + + return is_valid_cast && is_valid_rank; } + /** + * @brief Allows casting the vector_type to another type X via aliasing or slicing. + * Use cases are to expose the internal storage type, or to slice the vector into smaller + * vectors for iteration purposes. + * @tparam X The target type to validate for AsType casting. + * @returns a reference to the reinterpreted data as StaticallyIndexedArray_v2. + * Rigid control of allowable casts is enforced via static_assert to ensure type safety. + * See is_as_type_cast_valid() for requirements. + */ + template + __host__ __device__ constexpr auto const& AsType() const [[clang::lifetimebound]] + { + // Make this a hard error if the datatype X is not a valid cast. + static_assert(is_as_type_cast_valid(), "Datatype X is not a valid AsType cast"); + + using TraitsX = vector_type_traits; + + // Calculate the new rank after slicing. + // Note: We might end up with incomplete quantization from slicing + // when Rank % TraitsX::Rank != 0, so take the floor division. + constexpr index_t newRank = Rank / TraitsX::Rank; + + // Determine the cast type: + // - Scalar T if slicing to scalar or vector size of 1, + // - X otherwise. + using CastT = conditional_t; + using ResultT = StaticallyIndexedArray_v2; + + // As a rule, the aliasing type should not be larger than the original type. + static_assert(sizeof(ResultT) <= sizeof(vector_type), + "Resulting aliasing cannot be larger than original type"); + + // Re-cast as vectorized type. + return *(bit_cast(this)); + } + + /** + * @brief Allows casting the vector_type to another type X via aliasing or slicing. + * Use cases are to expose the internal storage type, or to slice the vector into smaller + * vectors for iteration purposes. + * @tparam X The target type to validate for AsType casting. + * @returns a reference to the reinterpreted data as StaticallyIndexedArray_v2. + * Rigid control of allowable casts is enforced via static_assert to ensure type safety. + * See is_as_type_cast_valid() for requirements. + */ template __host__ __device__ constexpr auto& AsType() [[clang::lifetimebound]] { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); + // Make this a hard error if the datatype X is not a valid cast. + static_assert(is_as_type_cast_valid(), "Datatype X is not a valid AsType cast"); - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } + using TraitsX = vector_type_traits; + + // Calculate the new rank after slicing. + // Note: We might end up with incomplete quantization from slicing + // when Rank % TraitsX::Rank != 0, so take the floor division. + constexpr index_t newRank = Rank / TraitsX::Rank; + + // Determine the cast type: + // - Scalar T if slicing to scalar or vector size of 1, + // - X otherwise. + using CastT = conditional_t; + using ResultT = StaticallyIndexedArray_v2; + + // As a rule, the aliasing type should not be larger than the original type. + static_assert(sizeof(ResultT) <= sizeof(vector_type), + "Resulting aliasing cannot be larger than original type"); + + // Re-cast as vectorized type. + return *(bit_cast(this)); } }; -template -struct vector_type()>> +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +/** + * @brief scalar_type trait override for vector_type + * @tparam T The vector type + * @tparam N The number of elements in the vector + */ +template +struct scalar_type> { - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - - using type = d16_t; - - union alignas(next_pow2(16 * sizeof(T))) - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } + using type = typename scalar_type::type; + static constexpr index_t vector_size = N; }; -template -struct vector_type()>> +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - - using type = d32_t; - - union alignas(next_pow2(32 * sizeof(T))) - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } + using type = vector_type; }; -template -struct vector_type()>> +template +struct vector_type_maker, N0> { - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - using d64_t = non_native_vector_base; - - using type = d64_t; - - union alignas(next_pow2(64 * sizeof(T))) - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } + using type = vector_type; }; +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index 4e477eed26..00fab270e8 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -260,7 +260,10 @@ struct DynamicBuffer x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); } else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && - is_same>::type, int8_t>::value && + is_same_v>::type, int8_t> && + !is_same_v, + pk_i4_t> && // TODO: This needs to be fixed for pk_i4_t which + // cannot be handled below, but is stored as int8_t workaround_int8_ds_write_issue) { if(is_valid_element) diff --git a/include/ck/utility/math.hpp b/include/ck/utility/math.hpp index b2ebf4b371..ff0d22b1a8 100644 --- a/include/ck/utility/math.hpp +++ b/include/ck/utility/math.hpp @@ -222,16 +222,28 @@ template __host__ __device__ constexpr auto next_power_of_two() { // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail - constexpr index_t Y = 1 << (32 - __builtin_clz(X - 1)); + constexpr index_t Y = X > 1 ? (1 << (32 - __builtin_clz(X - 1))) : X; return Y; } template -__host__ __device__ constexpr auto next_power_of_two(Number x) +__host__ __device__ constexpr auto next_power_of_two(Number) { - // TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail - constexpr index_t Y = 1 << (32 - __builtin_clz(x.value - 1)); - return Number{}; + return Number()>{}; +} + +__host__ __device__ constexpr int32_t integer_log2_floor(int32_t x) +{ + // x valid for 1 ~ 0x7fffffff + // __builtin_clz will produce unexpected result if x is 0; + return (x > 0) ? (31 - __builtin_clz(x)) : -1; +} + +__host__ __device__ constexpr bool is_power_of_two_integer(int32_t x) +{ + // x valid for 1 ~ 0x7fffffff + // Powers of 2 always positive + return (x > 0) ? !(x & (x - 1)) : false; } } // namespace math diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 161c4d37c3..11f0053585 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -1841,7 +1841,7 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 float float_array[32]; } in{x}; - using array_type = uint8_t __attribute__((ext_vector_type(32))); + using array_type = NativeVectorT; array_type uint8_array; // collect the 6-bit values into an array @@ -2178,7 +2178,7 @@ inline __host__ __device__ bf6x32_t bf6_convert_rne(float32_t x, float scale = 1 float float_array[32]; } in{x}; - using array_type = uint8_t __attribute__((ext_vector_type(32))); + using array_type = NativeVectorT; array_type uint8_array; // collect the 6-bit values into an array