mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-20 07:07:43 +00:00
[CK] Optimize vector type build times **Supersedes https://github.com/ROCm/rocm-libraries/pull/4281 due to CI issues on import** ## Proposed changes Build times can be affected by many different things and are largely attributable to the way we write and use the code. Two critical areas of the builds are **frontend parsing** and **backend codegen and compilation**. ### Frontend Parsing The length of the code, the include header tree and macro expansions all affect the front-end parsing time. This PR seeks to reduce the parsing time of the dtype_vector.hpp vector_type class by reducing redundant code through generalization. * Partial specializations of vector_type for native and non-native datatypes have been generalized to one single class, consolidating all of the data initialization and AsType casting requirements into one place. * The nnvb_data_t_selector class (i.e., the non-native vector base dataT selector) has been removed and replaced with scalar_type instantiations, as they have the same purpose. The purpose of the scalar_type class is already to map generalized datatypes to native types compatible with ext_vector_t. ### Backend Codegen Template instantiation behavior can also affect build times. Recursive instantiations are very slow versus concrete instantiations. The compiler must make multiple passes to expand template instantiations, so we need to be careful about how they are used. * Previous vector_type classes declared a union storage class, which aliases StaticallyIndexedArray<T,N>. ``` template <typename T> struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>> { using d1_t = T; typedef T d2_t __attribute__((ext_vector_type(2))); typedef T d4_t __attribute__((ext_vector_type(4))); using type = d4_t; union { d4_t d4_; StaticallyIndexedArray<d1_t, 4> d1x4_; StaticallyIndexedArray<d2_t, 2> d2x2_; StaticallyIndexedArray<d4_t, 1> d4x1_; } data_; ... 
}; ``` * Upon further inspection, StaticallyIndexedArray is built on top of a recursive Tuple concatenation. ``` template <typename T, index_t N> struct StaticallyIndexedArrayImpl { using type = typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type, typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type; }; ``` This union storage has been removed from the vector_type storage class. * Further references to StaticallyIndexedArray have been replaced with StaticallyIndexedArray_v2, which is a concrete implementation using C-style arrays. ``` template <typename T, index_t N> struct StaticallyIndexedArray_v2 { ... T data_[N]; }; ``` ### Fixes * Using the bool datatype with vector_type was previously error prone. bool, as a native datatype, would be stored as bool ext_vector_type(N), which is a packed datatype. This means that, for example, sizeof(bool ext_vector_type(4)) == 1, which does not equal sizeof(StaticallyIndexedArray<bool ext_vector_type(1), 4>) == 4. The union of these datatypes has incorrect data slicing, meaning that the bit locations of the packed bools do not match those of the StaticallyIndexedArray member. As such, vector_type will use C-style array storage for the bool type instead of ext_vector_type. ``` template <typename T, index_t Rank> using NativeVectorT = T __attribute__((ext_vector_type(Rank))); sizeof(NativeVectorT<bool, 4>) == 1 (1 byte per 4 bool - packed) element0 = bit 0 of byte 0 element1 = bit 1 of byte 0 element2 = bit 2 of byte 0 element3 = bit 3 of byte 0 sizeof(StaticallyIndexedArray<NativeVectorT<bool, 1>, 4>) == 4 (1 byte per bool) element0 = bit 0 of byte 0 element1 = bit 0 of byte 1 element2 = bit 0 of byte 2 element3 = bit 0 of byte 3 union{ NativeVectorT<bool, 4> d1_t; ... StaticallyIndexedArray<NativeVectorT<bool,1>, 4> d4x1; }; // union size == 4 which means invalid slicing! 
``` * Math utilities such as next_power_of_two addressed for invalid cases of X < 2 * Remove redundant implementation of next_pow2 ### Additions * integer_log2_floor to math.hpp * is_power_of_two_integer to math.hpp ### Build Time Analysis Machine: banff-cyxtera-s78-2 Target: gfx942 | Build Target | Threads | Frontend Parse Time (s) | Backend Codegen Time (s) | TotalTime (s) | commitId | |---------------|---------|-------------------------|--------------------------|---------------|
251 lines
5.2 KiB
C++
251 lines
5.2 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "integral_constant.hpp"
|
|
#include "number.hpp"
|
|
#include "type.hpp"
|
|
#include "enable_if.hpp"
|
|
|
|
namespace ck {
|
|
namespace math {
|
|
|
|
template <typename T, T s>
|
|
struct scales
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a) const { return s * a; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct plus
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct minus
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a, T b) const { return a - b; }
|
|
};
|
|
|
|
struct multiplies
|
|
{
|
|
template <typename A, typename B>
|
|
__host__ __device__ constexpr auto operator()(const A& a, const B& b) const
|
|
{
|
|
return a * b;
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct maximize
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct minimize
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct integer_divide_ceiler
|
|
{
|
|
__host__ __device__ constexpr T operator()(T a, T b) const
|
|
{
|
|
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
|
|
|
|
return (a + b - Number<1>{}) / b;
|
|
}
|
|
};
|
|
|
|
template <typename X, typename Y>
|
|
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
|
|
{
|
|
return x / y;
|
|
}
|
|
|
|
template <typename X, typename Y>
|
|
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
|
|
{
|
|
return (x + y - Number<1>{}) / y;
|
|
}
|
|
|
|
template <typename X, typename Y>
|
|
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
|
|
{
|
|
return y * integer_divide_ceil(x, y);
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr T max(T x)
|
|
{
|
|
return x;
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr T max(T x, T y)
|
|
{
|
|
return x > y ? x : y;
|
|
}
|
|
|
|
template <index_t X>
|
|
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
|
|
{
|
|
return X > y ? X : y;
|
|
}
|
|
|
|
template <index_t Y>
|
|
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
|
|
{
|
|
return x > Y ? x : Y;
|
|
}
|
|
|
|
template <typename X, typename... Ys>
|
|
__host__ __device__ constexpr auto max(X x, Ys... ys)
|
|
{
|
|
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
|
|
|
return max(x, max(ys...));
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr T min(T x)
|
|
{
|
|
return x;
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr T min(T x, T y)
|
|
{
|
|
return x < y ? x : y;
|
|
}
|
|
|
|
template <index_t X>
|
|
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
|
|
{
|
|
return X < y ? X : y;
|
|
}
|
|
|
|
template <index_t Y>
|
|
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
|
|
{
|
|
return x < Y ? x : Y;
|
|
}
|
|
|
|
template <typename X, typename... Ys>
|
|
__host__ __device__ constexpr auto min(X x, Ys... ys)
|
|
{
|
|
static_assert(sizeof...(Ys) > 0, "not enough argument");
|
|
|
|
return min(x, min(ys...));
|
|
}
|
|
|
|
template <typename T>
|
|
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
|
|
{
|
|
return min(max(x, lowerbound), upperbound);
|
|
}
|
|
|
|
// greatest common divisor, aka highest common factor
|
|
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
|
|
{
|
|
if(x < 0)
|
|
{
|
|
return gcd(-x, y);
|
|
}
|
|
else if(y < 0)
|
|
{
|
|
return gcd(x, -y);
|
|
}
|
|
else if(x == y || x == 0)
|
|
{
|
|
return y;
|
|
}
|
|
else if(y == 0)
|
|
{
|
|
return x;
|
|
}
|
|
else if(x > y)
|
|
{
|
|
return gcd(x % y, y);
|
|
}
|
|
else
|
|
{
|
|
return gcd(x, y % x);
|
|
}
|
|
}
|
|
|
|
template <index_t X, index_t Y>
|
|
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
|
{
|
|
constexpr auto r = gcd(X, Y);
|
|
|
|
return Number<r>{};
|
|
}
|
|
|
|
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
|
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
|
{
|
|
return gcd(x, gcd(ys...));
|
|
}
|
|
|
|
// least common multiple
|
|
template <typename X, typename Y>
|
|
__host__ __device__ constexpr auto lcm(X x, Y y)
|
|
{
|
|
return (x * y) / gcd(x, y);
|
|
}
|
|
|
|
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
|
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
|
|
{
|
|
return lcm(x, lcm(ys...));
|
|
}
|
|
|
|
template <typename T>
|
|
struct equal
|
|
{
|
|
__host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct less
|
|
{
|
|
__host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; }
|
|
};
|
|
|
|
template <index_t X>
|
|
__host__ __device__ constexpr auto next_power_of_two()
|
|
{
|
|
// TODO: X need to be 2 ~ 0x7fffffff. 0, 1, or larger than 0x7fffffff will compile fail
|
|
constexpr index_t Y = X > 1 ? (1 << (32 - __builtin_clz(X - 1))) : X;
|
|
return Y;
|
|
}
|
|
|
|
template <index_t X>
|
|
__host__ __device__ constexpr auto next_power_of_two(Number<X>)
|
|
{
|
|
return Number<next_power_of_two<X>()>{};
|
|
}
|
|
|
|
// Floor of log2(x) for positive x, i.e. the index of the highest set bit.
// Valid inputs are 1 ~ 0x7fffffff; for x <= 0 the function returns -1.
__host__ __device__ constexpr int32_t integer_log2_floor(int32_t x)
{
    // __builtin_clz gives an unspecified result for 0, so reject
    // non-positive inputs before calling it.
    if(x <= 0)
    {
        return -1;
    }

    return 31 - __builtin_clz(x);
}
|
|
|
|
// True when x is a positive power of two (1, 2, 4, ...).
// Valid inputs are 1 ~ 0x7fffffff; zero and negative values yield false,
// since powers of two are always positive.
__host__ __device__ constexpr bool is_power_of_two_integer(int32_t x)
{
    // A power of two has exactly one bit set, so clearing the lowest set bit
    // with x & (x - 1) leaves zero.
    return x > 0 && (x & (x - 1)) == 0;
}
|
|
|
|
} // namespace math
|
|
} // namespace ck
|