mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-28 02:57:42 +00:00
[CK] Optimize vector type build times

**Supersedes https://github.com/ROCm/rocm-libraries/pull/4281 due to CI issues on import**

## Proposed changes

Build times are affected by many factors, and they depend heavily on the way we write and use the code. Two critical areas of the build are **frontend parsing** and **backend codegen and compilation**.

### Frontend Parsing

The length of the code, the header include tree, and macro expansions all affect frontend parsing time. This PR aims to reduce the parsing time of the `vector_type` class in `dtype_vector.hpp` by generalizing away redundant code.

* The partial specializations of `vector_type` for native and non-native datatypes have been generalized into a single class, consolidating all data initialization and `AsType` casting logic in one place.
* The `nnvb_data_t_selector` (i.e., non-native vector base data type selector) class has been removed and replaced with `scalar_type` instantiations, since both serve the same purpose: `scalar_type` already maps generalized datatypes to native types compatible with `ext_vector_t`.

### Backend Codegen

Template instantiation behavior also affects build times. Recursive instantiations are much slower than concrete ones, because the compiler must expand every level of the recursion as a separate instantiation, so we need to be careful about how they are used.

* Previous `vector_type` classes declared a storage union that aliased the native vector with `StaticallyIndexedArray<T, N>` members:

```
template <typename T>
struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));

    using type = d4_t;

    union
    {
        d4_t d4_;
        StaticallyIndexedArray<d1_t, 4> d1x4_;
        StaticallyIndexedArray<d2_t, 2> d2x2_;
        StaticallyIndexedArray<d4_t, 1> d4x1_;
    } data_;

    ...
};
```

* On further inspection, `StaticallyIndexedArray` is built on top of a recursive `Tuple` concatenation:

```
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{
    using type = typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
                                       typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
};
```

This union storage has been removed from the `vector_type` storage class.

* Remaining references to `StaticallyIndexedArray` have been replaced with `StaticallyIndexedArray_v2`, a concrete implementation backed by a C-style array:

```
template <typename T, index_t N>
struct StaticallyIndexedArray_v2
{
    ...
    T data_[N];
};
```

### Fixes

* Using the `bool` datatype with `vector_type` was previously error-prone. As a native datatype, `bool` was stored in a `bool ext_vector_type(N)`, which is a packed type. For example, `sizeof(bool ext_vector_type(4)) == 1`, which does not equal `sizeof(StaticallyIndexedArray<bool ext_vector_type(1), 4>) == 4`. The union of these types therefore slices the data incorrectly: the bit positions of the packed `bool` elements do not match those of the `StaticallyIndexedArray` member. As a result, `vector_type` now uses C-style array storage for `bool` instead of `ext_vector_type`.

```
template <typename T, index_t Rank>
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));

// sizeof(NativeVectorT<bool, 4>) == 1 (4 bools packed into 1 byte)
//   element0 = bit 0 of byte 0
//   element1 = bit 1 of byte 0
//   element2 = bit 2 of byte 0
//   element3 = bit 3 of byte 0

// sizeof(StaticallyIndexedArray<NativeVectorT<bool, 1>, 4>) == 4 (1 byte per bool)
//   element0 = bit 0 of byte 0
//   element1 = bit 0 of byte 1
//   element2 = bit 0 of byte 2
//   element3 = bit 0 of byte 3

union
{
    NativeVectorT<bool, 4> d4_;
    ...
    StaticallyIndexedArray<NativeVectorT<bool, 1>, 4> d1x4_;
}; // union size == 4, so the members slice the data inconsistently!
```

* Math utilities such as `next_power_of_two` now handle the previously invalid cases of `X < 2`.
* Removed the redundant `next_pow2` implementation.

### Additions

* Added `integer_log2_floor` to `math.hpp`
* Added `is_power_of_two_integer` to `math.hpp`

### Build Time Analysis

Machine: banff-cyxtera-s78-2

Target: gfx942

| Build Target | Threads | Frontend Parse Time (s) | Backend Codegen Time (s) | Total Time (s) | Commit ID |
|---------------|---------|-------------------------|--------------------------|---------------|-----------|
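The consolidation described under Frontend Parsing replaces separate native and non-native partial specializations with one class. That class takes its storage element from a trait. The sketch below shows only that shape, and every name in it is hypothetical:

* `storage_scalar` plays the role the PR assigns to `scalar_type`.
* `f8_sketch_t` stands in for a non-native datatype.
* `vector_type_sketch` uses a C-style array instead of `ext_vector_type`, so it compiles on any C++14 compiler.

This is an illustration under those assumptions, not CK's `vector_type`.

```cpp
#include <cstddef>
#include <cstdint>
#include <type_traits>

// Hypothetical non-native 8-bit float: a wrapper around its raw bits.
struct f8_sketch_t
{
    std::uint8_t data;
};

// Hypothetical trait mapping a datatype to a native scalar usable as storage.
// Native arithmetic types map to themselves.
template <typename T>
struct storage_scalar
{
    static_assert(std::is_arithmetic<T>::value, "add a specialization for non-native types");
    using type = T;
};

// Non-native types map to their underlying bit container.
template <>
struct storage_scalar<f8_sketch_t>
{
    using type = std::uint8_t;
};

// One class for every datatype: no separate native / non-native specializations.
template <typename T, std::size_t N>
struct vector_type_sketch
{
    using value_type   = T;
    using storage_type = typename storage_scalar<T>::type;

    storage_type data_[N];

    // Read element i in its storage representation.
    constexpr storage_type get(std::size_t i) const { return data_[i]; }
};

static_assert(std::is_same<vector_type_sketch<float, 4>::storage_type, float>::value, "native");
static_assert(std::is_same<vector_type_sketch<f8_sketch_t, 8>::storage_type, std::uint8_t>::value,
              "non-native");
```

Adding a new non-native type in this shape means adding one trait specialization rather than another partial specialization of the vector class. That is the parsing-time saving the section describes.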
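The cost difference described under Backend Codegen can be seen in a standalone sketch. It contrasts a `tuple_concat`-style recursive type construction with a concrete C-array wrapper. The names are hypothetical stand-ins, and this is not CK's actual code:

* `TypeList` stands in for `Tuple`.
* `Concat` stands in for `tuple_concat`.
* `RecursiveArrayImpl` stands in for `StaticallyIndexedArrayImpl`.
* `ConcreteArray` stands in for `StaticallyIndexedArray_v2`.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for CK's Tuple: a bare list of element types.
template <typename... Ts>
struct TypeList
{
    static constexpr std::size_t size = sizeof...(Ts);
};

// Stand-in for tuple_concat: joins two type lists into one.
template <typename A, typename B>
struct Concat;

template <typename... As, typename... Bs>
struct Concat<TypeList<As...>, TypeList<Bs...>>
{
    using type = TypeList<As..., Bs...>;
};

// Recursive construction in the style of StaticallyIndexedArrayImpl:
// every distinct (T, N) reached by the N / 2 split, and every Concat it
// uses, is a separate class template instantiation.
template <typename T, std::size_t N>
struct RecursiveArrayImpl
{
    using type = typename Concat<typename RecursiveArrayImpl<T, N / 2>::type,
                                 typename RecursiveArrayImpl<T, N - N / 2>::type>::type;
};

template <typename T>
struct RecursiveArrayImpl<T, 1>
{
    using type = TypeList<T>;
};

template <typename T>
struct RecursiveArrayImpl<T, 0>
{
    using type = TypeList<>;
};

// Concrete storage in the style of StaticallyIndexedArray_v2:
// a single instantiation per (T, N), backed by a plain C array.
template <typename T, std::size_t N>
struct ConcreteArray
{
    T data_[N];

    T& operator[](std::size_t i) { return data_[i]; }
    constexpr const T& operator[](std::size_t i) const { return data_[i]; }
};

// Both describe N elements of type T.
static_assert(RecursiveArrayImpl<float, 8>::type::size == 8, "8 elements");
static_assert(std::is_same<RecursiveArrayImpl<int, 3>::type, TypeList<int, int, int>>::value,
              "3 ints");
static_assert(sizeof(ConcreteArray<float, 8>) == 8 * sizeof(float), "no padding for float");
```

The two constructions differ in how many templates the compiler must create:

* `RecursiveArrayImpl<float, 8>` instantiates itself for N = 8, 4, 2, and 1, plus one `Concat` at each non-leaf level.
* `ConcreteArray<float, 8>` is a single instantiation.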
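The `bool` slicing problem described under Fixes can be modeled in portable C++ without `ext_vector_type`. This sketch makes two assumptions:

* It represents the packed `NativeVectorT<bool, 4>` layout as a `uint8_t` bitmask, where element `i` is bit `i` of byte 0.
* It assumes the common ABI in which `true` is stored as the byte `0x01`.

The C-style array stores one `bool` per byte. Reading the array's first byte through the packed interpretation, as the old union effectively did, returns the wrong elements. The helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Packed layout, as described for NativeVectorT<bool, 4>:
// element i lives in bit i of byte 0, so 4 bools occupy 1 byte.
inline bool packed_get(std::uint8_t bits, int i)
{
    assert(i >= 0 && i < 8);
    return ((bits >> i) & 1u) != 0;
}

// C-style array layout, as used for bool after this change:
// element i lives in its own byte (sizeof(bool) == 1 on mainstream ABIs).
struct BoolArray4
{
    bool data_[4];
};

// Emulate the old union: take the array's byte 0 and interpret it as
// the packed vector, then read element i from it.
inline bool read_array_as_packed(const BoolArray4& a, int i)
{
    std::uint8_t byte0 = 0;
    std::memcpy(&byte0, &a.data_[0], 1); // object representation of element 0
    return packed_get(byte0, i);
}
```

Take the array `{true, false, true, true}`:

* Reading the array directly gives `true` for element 2.
* The packed view of its byte 0 (`0x01`) gives `false` for element 2, which is the mismatch that motivated switching `bool` to array storage.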
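The math changes listed under Fixes and Additions can be sketched as `constexpr` helpers. The names `integer_log2_floor`, `is_power_of_two_integer`, and `next_power_of_two` come from this PR. The following are all assumptions for illustration, not the code in `math.hpp`:

* the signatures, written as plain `constexpr` functions over a 32-bit `index_t`;
* the loop-based implementations;
* the choice to return 1 for `X < 2`.

```cpp
#include <cstdint>

using index_t = std::int32_t; // assumption: a 32-bit signed index type

// floor(log2(x)) for x >= 1; this sketch returns 0 for x <= 1.
constexpr index_t integer_log2_floor(index_t x)
{
    index_t r = 0;
    while(x > 1)
    {
        x >>= 1;
        ++r;
    }
    return r;
}

// True when x has exactly one bit set (1, 2, 4, ...); false for x <= 0.
constexpr bool is_power_of_two_integer(index_t x)
{
    return x > 0 && (x & (x - 1)) == 0;
}

// Smallest power of two >= x, valid for x <= 2^30 in this sketch.
// The common formula 1 << (32 - __builtin_clz(x - 1)) evaluates
// __builtin_clz(0), which is undefined, when x == 1; here inputs below 2
// are handled explicitly and return 1 (2^0).
constexpr index_t next_power_of_two(index_t x)
{
    if(x < 2)
        return 1;
    return is_power_of_two_integer(x) ? x : index_t{1} << (integer_log2_floor(x) + 1);
}
```

Because all three helpers are `constexpr`, they can be used in template arguments and `static_assert`s. For example, `next_power_of_two(5)` yields 8 and `next_power_of_two(1)` yields 1.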