Files
composable_kernel/include/ck/utility/math.hpp
Christopher Millette 04eddbc5ce [rocm-libraries] ROCm/rocm-libraries#4471 (commit 10fa702)
[CK] Optimize vector type build times

**Supersedes https://github.com/ROCm/rocm-libraries/pull/4281 due to CI
issues on import**

## Proposed changes

Build times can be affected by many different things and are highly
dependent on the way we write and use the code. Two critical areas of
the build are **frontend parsing** and **backend codegen and
compilation**.

### Frontend Parsing
The length of the code, the include header tree and macro expansions all
affect the front-end parsing time.
This PR seeks to reduce the parsing time of the dtype_vector.hpp
vector_type class by reducing redundant code by generalization.
* Partial specializations of vector_type for native and non-native
datatypes have been generalized to one single class, consolidating all
of the data initialization and AsType casting requirements into one
place.
* The nnvb_data_t_selector (i.e., non-native vector base data_t
selector) class has been removed and replaced with scalar_type
instantiations, as they have the same purpose: the scalar_type class
already maps generalized datatypes to native types compatible with
ext_vector_t.

### Backend Codegen
Template instantiation behavior can also affect build times. Recursive
instantiations are very slow versus concrete instantiations. The
compiler must make multiple passes to expand template instantiations so
we need to be careful about how they are used.
* Previous vector_type classes declared a union storage class, which
aliases StaticallyIndexedArray<T,N>.
```
template <typename T>
struct vector_type<T, 4, typename ck::enable_if_t<is_native_type<T>()>>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));

    using type = d4_t;

    union
    {
        d4_t d4_;
        StaticallyIndexedArray<d1_t, 4> d1x4_;
        StaticallyIndexedArray<d2_t, 2> d2x2_;
        StaticallyIndexedArray<d4_t, 1> d4x1_;
    } data_;
   ...
};
```
* Upon further inspection, StaticallyIndexedArray is built on-top of a
recursive Tuple concatenation.
```
template <typename T, index_t N>
struct StaticallyIndexedArrayImpl
{
    using type =
        typename tuple_concat<typename StaticallyIndexedArrayImpl<T, N / 2>::type,
                              typename StaticallyIndexedArrayImpl<T, N - N / 2>::type>::type;
};
```
This union storage has been removed from the vector_type storage class.

* Further references to StaticallyIndexedArray have been replaced with
StaticallyIndexedArray_v2, which is a concrete implementation using
C-style arrays.
```
template <typename T, index_t N>
struct StaticallyIndexedArray_v2
{
    ...

    T data_[N];
};
```

### Fixes
* Using bool datatype with vector_type was previously error prone. Bool,
as a native datatype would be stored into bool ext_vector_type(N) for
storage, which is a packed datatype. Meaning that for example,
sizeof(bool ext_vector_type(4)) == 1, which does not equal
sizeof(StaticallyIndexedArray<bool ext_vector_type(1), 4>) == 4. The
union of these datatypes has incorrect data slicing, meaning that the
bits location of the packed bool do not match with the
StaticallyIndexedArray member. As such, vector_type will use C-Style
array storage for bool type instead of ext_vector_type.
```
template <typename T, index_t Rank>
using NativeVectorT = T __attribute__((ext_vector_type(Rank)));

sizeof(NativeVectorT<bool, 4>) == 1  (1 byte per 4 bool - packed)
element0 = bit 0 of byte 0
element1 = bit 1 of byte 0
element2 = bit 2 of byte 0
element3 = bit 3 of byte 0

sizeof(StaticallyIndexedArray[NativeVectorT<bool, 1>, 4]) == 4  (1 byte per bool)
element0 = bit 0 of byte 0
element1 = bit 0 of byte 1
element2 = bit 0 of byte 2
element3 = bit 0 of byte 3

union{
    NativeVectorT<bool, 4> d1_t;
    ...
    StaticallyIndexedArray[NativeVectorT<bool,1>, 4] d4x1;
};

// union size == 4 which means invalid slicing!
```
* Math utilities such as next_power_of_two have been fixed to handle
previously invalid cases of X < 2
* Remove redundant implementation of next_pow2

### Additions
* integer_log2_floor to math.hpp
* is_power_of_two_integer to math.hpp

### Build Time Analysis

Machine:  banff-cyxtera-s78-2
Target: gfx942

| Build Target | Threads | Frontend Parse Time (s) | Backend Codegen
Time (s) | TotalTime (s) | commitId |

|---------------|---------|-------------------------|--------------------------|---------------|
2026-02-11 19:01:05 +00:00

251 lines
5.2 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
#include "enable_if.hpp"
namespace ck {
namespace math {
// Unary functor: multiplies its argument by the compile-time constant `s`.
template <typename T, T s>
struct scales
{
    __host__ __device__ constexpr T operator()(T value) const { return s * value; }
};
// Binary functor computing the sum of its two arguments.
template <typename T>
struct plus
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const { return lhs + rhs; }
};
// Binary functor computing the difference of its two arguments.
template <typename T>
struct minus
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const { return lhs - rhs; }
};
// Heterogeneous binary functor computing the product of its two arguments;
// the result type follows the usual rules of the expression `a * b`.
struct multiplies
{
    template <typename A, typename B>
    __host__ __device__ constexpr auto operator()(const A& lhs, const B& rhs) const
    {
        return lhs * rhs;
    }
};
// Binary functor returning the larger of its two arguments
// (the first argument wins ties, matching `a >= b ? a : b`).
template <typename T>
struct maximize
{
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        if(a >= b)
        {
            return a;
        }
        return b;
    }
};
// Binary functor returning the smaller of its two arguments
// (the first argument wins ties, matching `a <= b ? a : b`).
template <typename T>
struct minimize
{
    __host__ __device__ constexpr T operator()(T a, T b) const
    {
        if(a <= b)
        {
            return a;
        }
        return b;
    }
};
// Binary functor computing the rounded-up integer quotient of its operands.
// Only index_t and int are accepted; the adjustment below assumes a >= 0 and
// b > 0 (standard ceil-division idiom).
template <typename T>
struct integer_divide_ceiler
{
__host__ __device__ constexpr T operator()(T a, T b) const
{
static_assert(is_same<T, index_t>{} || is_same<T, int>{}, "wrong type");
// add (b - 1) before truncating division so partial blocks round up;
// Number<1>{} is used so the expression also composes with ck::Number types
return (a + b - Number<1>{}) / b;
}
};
// Rounded-down integer quotient of x / y (plain truncating division,
// so "floor" semantics hold for non-negative operands).
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
{
    const auto quotient = x / y;
    return quotient;
}
// Rounded-up integer quotient of x / y; assumes x >= 0 and y > 0.
// Number<1>{} (rather than a literal 1) keeps the expression composable with
// ck::Number compile-time operands.
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
{
return (x + y - Number<1>{}) / y;
}
// Smallest multiple of y that is >= x (assumes x >= 0 and y > 0).
template <typename X, typename Y>
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
{
    const auto num_blocks = integer_divide_ceil(x, y);
    return y * num_blocks;
}
// Base case of the variadic max: a single value is its own maximum.
template <typename T>
__host__ __device__ constexpr T max(T value)
{
    return value;
}
// Maximum of two values of the same type (second argument wins ties,
// matching `x > y ? x : y`).
template <typename T>
__host__ __device__ constexpr T max(T x, T y)
{
    if(x > y)
    {
        return x;
    }
    return y;
}
// Mixed overload: compile-time Number on the left, runtime index_t on the
// right; result is a runtime index_t.
template <index_t X>
__host__ __device__ constexpr index_t max(Number<X>, index_t y)
{
    if(X > y)
    {
        return X;
    }
    return y;
}
// Mixed overload: runtime index_t on the left, compile-time Number on the
// right; result is a runtime index_t.
template <index_t Y>
__host__ __device__ constexpr index_t max(index_t x, Number<Y>)
{
    if(x > Y)
    {
        return x;
    }
    return Y;
}
// Variadic max: reduces the tail first, then combines with the head via the
// binary overloads above.
template <typename X, typename... Ys>
__host__ __device__ constexpr auto max(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    const auto tail_max = max(ys...);
    return max(x, tail_max);
}
// Base case of the variadic min: a single value is its own minimum.
template <typename T>
__host__ __device__ constexpr T min(T value)
{
    return value;
}
// Minimum of two values of the same type (second argument wins ties,
// matching `x < y ? x : y`).
template <typename T>
__host__ __device__ constexpr T min(T x, T y)
{
    if(x < y)
    {
        return x;
    }
    return y;
}
// Mixed overload: compile-time Number on the left, runtime index_t on the
// right; result is a runtime index_t.
template <index_t X>
__host__ __device__ constexpr index_t min(Number<X>, index_t y)
{
    if(X < y)
    {
        return X;
    }
    return y;
}
// Mixed overload: runtime index_t on the left, compile-time Number on the
// right; result is a runtime index_t.
template <index_t Y>
__host__ __device__ constexpr index_t min(index_t x, Number<Y>)
{
    if(x < Y)
    {
        return x;
    }
    return Y;
}
// Variadic min: reduces the tail first, then combines with the head via the
// binary overloads above.
template <typename X, typename... Ys>
__host__ __device__ constexpr auto min(X x, Ys... ys)
{
    static_assert(sizeof...(Ys) > 0, "not enough argument");
    const auto tail_min = min(ys...);
    return min(x, tail_min);
}
// Clamps x into [lowerbound, upperbound]; if the bounds are inverted the
// upper bound wins, matching min(max(x, lower), upper).
template <typename T>
__host__ __device__ constexpr T clamp(const T& x, const T& lowerbound, const T& upperbound)
{
    const T raised = max(x, lowerbound);
    return min(raised, upperbound);
}
// Greatest common divisor, aka highest common factor.
// Iterative Euclidean algorithm: negative inputs contribute their absolute
// value, and gcd(0, 0) == 0 (as in the recursive formulation it replaces).
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
    index_t a = x < 0 ? -x : x;
    index_t b = y < 0 ? -y : y;

    while(b != 0)
    {
        const index_t remainder = a % b;
        a                       = b;
        b                       = remainder;
    }

    return a;
}
// Compile-time overload: gcd of two Number<> constants, returned as a
// Number<> so the result stays a compile-time quantity.
template <index_t X, index_t Y>
__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
{
    return Number<gcd(X, Y)>{};
}
// Variadic gcd over three or more arguments; two-argument calls resolve to
// the overloads above.
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
    const auto tail_gcd = gcd(ys...);
    return gcd(x, tail_gcd);
}
// Least common multiple of two values.
// NOTE(review): the product x * y is formed before dividing by the gcd, so
// it may overflow for large operands — confirm callers stay in range.
template <typename X, typename Y>
__host__ __device__ constexpr auto lcm(X x, Y y)
{
    const auto product = x * y;
    return product / gcd(x, y);
}
// Variadic lcm over three or more arguments; two-argument calls resolve to
// the overload above.
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
{
    const auto tail_lcm = lcm(ys...);
    return lcm(x, tail_lcm);
}
// Binary predicate: true when both arguments compare equal.
template <typename T>
struct equal
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const { return lhs == rhs; }
};
// Binary predicate: true when the first argument compares strictly less
// than the second.
template <typename T>
struct less
{
    __host__ __device__ constexpr bool operator()(T lhs, T rhs) const { return lhs < rhs; }
};
// Smallest power of two that is >= X for X >= 2; values of X < 2 are
// returned unchanged (0 -> 0, 1 -> 1), handled by the ternary below.
// NOTE(review): for X > 0x40000000 the shift yields 1 << 31, which overflows
// index_t and fails in constexpr evaluation — confirm callers obey this bound.
// __builtin_clz(X - 1) is well defined here since the X > 1 branch guarantees
// a nonzero argument.
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two()
{
constexpr index_t Y = X > 1 ? (1 << (32 - __builtin_clz(X - 1))) : X;
return Y;
}
// Number<> overload: lifts the next-power-of-two result back into a
// compile-time constant.
template <index_t X>
__host__ __device__ constexpr auto next_power_of_two(Number<X>)
{
    constexpr index_t y = next_power_of_two<X>();
    return Number<y>{};
}
// floor(log2(x)) for x in [1, 0x7fffffff]; returns -1 for x <= 0.
// The guard matters because __builtin_clz is undefined for a zero argument.
__host__ __device__ constexpr int32_t integer_log2_floor(int32_t x)
{
    if(x <= 0)
    {
        return -1;
    }
    return 31 - __builtin_clz(x);
}
// True iff x is a positive power of two. A power of two has exactly one set
// bit, so x & (x - 1) clears it to zero; non-positive values are never
// powers of two.
__host__ __device__ constexpr bool is_power_of_two_integer(int32_t x)
{
    return (x > 0) && ((x & (x - 1)) == 0);
}
} // namespace math
} // namespace ck