Add the 16-wide float vector typedef (float16_t) to float_type.amd.hpp.in and
delete the standalone float_type.hpp.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which every <...> span had been stripped. Line breaks were rebuilt to match the
hunk headers; template argument lists and cast target types were reconstructed
from the surrounding code and should be confirmed against the repository.

diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in
index 537d17daf7..fd9c0029bc 100644
--- a/composable_kernel/include/utility/float_type.amd.hpp.in
+++ b/composable_kernel/include/utility/float_type.amd.hpp.in
@@ -7,6 +7,7 @@ namespace ck {
 // float
 typedef float float2_t __attribute__((ext_vector_type(2)));
 typedef float float4_t __attribute__((ext_vector_type(4)));
+typedef float float16_t __attribute__((ext_vector_type(16)));
 typedef float float32_t __attribute__((ext_vector_type(32)));
 
 // float16
diff --git a/composable_kernel/include/utility/float_type.hpp b/composable_kernel/include/utility/float_type.hpp
deleted file mode 100644
index fd9c0029bc..0000000000
--- a/composable_kernel/include/utility/float_type.hpp
+++ /dev/null
@@ -1,311 +0,0 @@
-#ifndef CK_FLOAT_TYPE_AMD_HPP
-#define CK_FLOAT_TYPE_AMD_HPP
-
-namespace ck {
-
-// For some reason, HIP compiler need this definition to generate optimal ISA
-// float
-typedef float float2_t __attribute__((ext_vector_type(2)));
-typedef float float4_t __attribute__((ext_vector_type(4)));
-typedef float float16_t __attribute__((ext_vector_type(16)));
-typedef float float32_t __attribute__((ext_vector_type(32)));
-
-// float16
-typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
-typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
-
-// bfloat16
-typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
-typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
-
-template <class T, index_t N>
-struct vector_type
-{
-    typedef struct
-    {
-        T scalar[N];
-    } MemoryType;
-};
-
-template <>
-struct vector_type<float, 1>
-{
-    using MemoryType = float;
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
-    {
-        static_assert(I < 1, "wrong");
-        *(reinterpret_cast<float*>(&v) + I) = s;
-    }
-};
-
-template <>
-struct vector_type<float, 2>
-{
-    using MemoryType = float2_t;
-
-    union DataType
-    {
-        MemoryType vector;
-        float scalar[2];
-    };
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
-    {
-        static_assert(I < 2, "wrong");
-        *(reinterpret_cast<float*>(&v) + I) = s;
-    }
-
-    __host__ __device__ static MemoryType Pack(float s0, float s1)
-    {
-        DataType data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<float, 4>
-{
-    using MemoryType = float4_t;
-
-    __host__ __device__ static constexpr index_t GetSize() { return 4; }
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
-    {
-        static_assert(I < 4, "wrong");
-        *(reinterpret_cast<float*>(&v) + I) = s;
-    }
-};
-
-template <>
-struct vector_type<half, 1>
-{
-    using MemoryType = half;
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
-    {
-        static_assert(I < 1, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
-    }
-};
-
-template <>
-struct vector_type<half, 2>
-{
-    using MemoryType = half2_t;
-
-    union DataType
-    {
-        MemoryType vector;
-        half scalar[2];
-    };
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
-    {
-        static_assert(I < 2, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
-    }
-
-    __host__ __device__ static MemoryType Pack(half s0, half s1)
-    {
-        DataType data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<half, 4>
-{
-    using MemoryType = half4_t;
-
-    union DataType
-    {
-        MemoryType vector;
-        half scalar[4];
-    };
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
-    {
-        static_assert(I < 4, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
-    }
-
-    __host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
-    {
-        DataType data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        data.scalar[2] = s2;
-        data.scalar[3] = s3;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<ushort, 1>
-{
-    using MemoryType = ushort;
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
-    {
-        static_assert(I < 1, "wrong");
-        *(reinterpret_cast<ushort*>(&v) + I) = s;
-    }
-};
-
-template <>
-struct vector_type<ushort, 2>
-{
-    using MemoryType = ushort2_t;
-
-    union DataType
-    {
-        MemoryType vector;
-        ushort scalar[2];
-    };
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
-    {
-        static_assert(I < 2, "wrong");
-        *(reinterpret_cast<ushort*>(&v) + I) = s;
-    }
-
-    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1)
-    {
-        DataType data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        return data.vector;
-    }
-};
-
-template <>
-struct vector_type<ushort, 4>
-{
-    using MemoryType = ushort4_t;
-
-    union DataType
-    {
-        MemoryType vector;
-        ushort scalar[4];
-    };
-
-    template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
-    {
-        static_assert(I < 4, "wrong");
-        *(reinterpret_cast<ushort*>(&v) + I) = s;
-    }
-
-    __host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3)
-    {
-        DataType data;
-        data.scalar[0] = s0;
-        data.scalar[1] = s1;
-        data.scalar[2] = s2;
-        data.scalar[3] = s3;
-        return data.vector;
-    }
-};
-
-// data type conversion
-template <typename T>
-struct type_convert
-{
-    template <typename X>
-    __device__ T operator()(X x) const
-    {
-        return static_cast<T>(x);
-    }
-};
-
-template <>
-template <>
-__device__ float type_convert<float>::operator()<ushort>(ushort x) const
-{
-    return bfloat16_to_float(x);
-}
-
-template <>
-template <>
-__device__ ushort type_convert<ushort>::operator()<float>(float x) const
-{
-    return float_to_bfloat16(x);
-}
-
-template <typename T>
-struct inner_product_with_conversion
-{
-    static constexpr auto convert = type_convert<T>();
-
-    __device__ T operator()(float a, float b) const { return convert(a) * convert(b); }
-
-    __device__ T operator()(half2_t a, half2_t b) const
-    {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
-
-        T acc = 0;
-        for(index_t v = 0; v < 2; ++v)
-        {
-            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
-        }
-
-        return acc;
-    }
-
-    __device__ T operator()(half4_t a, half4_t b) const
-    {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
-
-        T acc = 0;
-        for(index_t v = 0; v < 4; ++v)
-        {
-            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
-        }
-        return acc;
-    }
-
-    __device__ T operator()(ushort2_t a, ushort2_t b) const
-    {
-        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
-        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
-
-        T acc = 0;
-        for(index_t v = 0; v < 2; ++v)
-        {
-            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
-        }
-
-        return acc;
-    }
-
-    __device__ T operator()(ushort4_t a, ushort4_t b) const
-    {
-        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
-        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);
-
-        T acc = 0;
-        for(index_t v = 0; v < 4; ++v)
-        {
-            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
-        }
-        return acc;
-    }
-};
-
-} // namespace ck
-#endif