diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index de59f200f0..0ed60df2c3 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -5,7 +5,7 @@ #define CK_AMD_INLINE_ASM_HPP #include "c_style_pointer_cast.hpp" -#include "data_type.hpp" +#include "dtype_vector.hpp" // TODO: deprecate all amd_assembly_outer_product_xxx diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 396e375d8c..0d4611becc 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -2,6 +2,7 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/dtype_fp64.hpp" namespace ck { // Define the common macro for MI300 models diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index a4d96edc6d..6b7aaf2162 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -346,53 +346,6 @@ inline constexpr bool is_native_type() is_same::value || is_same::value; } -// vector_type -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type; - -// Caution: DO NOT REMOVE -// intentionally have only declaration but no definition to cause compilation failure when trying to -// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of -// vectors" -template -struct vector_type, N>; - -// vector_type_maker -// This is the right way to handle "vector of vectors": making a bigger vector instead -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker -{ - using type = vector_type; -}; - -template -struct vector_type_maker, N0> -{ - using type = vector_type; -}; - -template -using vector_type_maker_t = typename vector_type_maker::type; - -template -__host__ __device__ constexpr auto make_vector_type(Number) -{ - return typename vector_type_maker::type{}; -} - // scalar_type template struct scalar_type; @@ -416,13 +369,6 @@ struct scalar_type static constexpr index_t vector_size = N; }; -template -struct scalar_type> -{ - using type = T; - static constexpr index_t vector_size = N; -}; - // template <> struct scalar_type @@ -524,2868 +470,10 @@ struct scalar_type static constexpr index_t vector_size = 1; }; -template -struct vector_type()>> -{ - using d1_t = T; - using type = d1_t; - - union - { - T d1_; - StaticallyIndexedArray d1x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value, - "Something went wrong, please check src and dst types."); - - return data_.d1x1_; - } -}; - -__device__ int static err = 0; -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - - using type = d2_t; - - union - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d3_t __attribute__((ext_vector_type(3))); - - using type = d3_t; - - union - { - d3_t d3_; - StaticallyIndexedArray d1x3_; - StaticallyIndexedArray d2x1_; - StaticallyIndexedArray d3x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x3_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else if constexpr(is_same::value) - { - return data_.d3x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - - using type = d4_t; - - union - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d5_t __attribute__((ext_vector_type(5))); - - using type = d5_t; - - union - { - d5_t d5_; - StaticallyIndexedArray d1x5_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d5x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x5_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d5x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d7_t __attribute__((ext_vector_type(7))); - - using type = d7_t; - - union - { - d7_t d7_; - StaticallyIndexedArray d1x7_; - StaticallyIndexedArray d2x3_; - StaticallyIndexedArray d4x1_; - StaticallyIndexedArray d7x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x7_; - } - else if constexpr(is_same::value) - { - return data_.d2x3_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else if constexpr(is_same::value) - { - return data_.d7x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - - using type = d8_t; - - union - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d13_t __attribute__((ext_vector_type(13))); - - using type = d13_t; - - union - { - d13_t d13_; - StaticallyIndexedArray d1x13_; - StaticallyIndexedArray d4x3_; - StaticallyIndexedArray d8x1_; - StaticallyIndexedArray d13x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x13_; - } - else if constexpr(is_same::value) - { - return data_.d4x3_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else if constexpr(is_same::value) - { - return data_.d13x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - - using type = d16_t; - - union - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - - using type = d32_t; - - union - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_ = {d32_t{0}}; - - __attribute__((host)) __attribute__((device)) constexpr vector_type() {} - - __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } - - // __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - // __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - - using type = d64_t; - - union - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - - using type = d128_t; - - union - { - d128_t d128_; - StaticallyIndexedArray d1x128_; - StaticallyIndexedArray d2x64_; - StaticallyIndexedArray d4x32_; - StaticallyIndexedArray d8x16_; - StaticallyIndexedArray d16x8_; - StaticallyIndexedArray d32x4_; - StaticallyIndexedArray d64x2_; - StaticallyIndexedArray d128x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x128_; - } - else if constexpr(is_same::value) - { - return data_.d2x64_; - } - else if constexpr(is_same::value) - { - return data_.d4x32_; - } - else if constexpr(is_same::value) - { - return data_.d8x16_; - } - else if constexpr(is_same::value) - { - return data_.d16x8_; - } - else if constexpr(is_same::value) - { - return data_.d32x4_; - } - else if constexpr(is_same::value) - { - return data_.d64x2_; - } - else if constexpr(is_same::value) - { - return data_.d128x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - typedef T d2_t __attribute__((ext_vector_type(2))); - typedef T d4_t __attribute__((ext_vector_type(4))); - typedef T d8_t __attribute__((ext_vector_type(8))); - typedef T d16_t __attribute__((ext_vector_type(16))); - typedef T d32_t __attribute__((ext_vector_type(32))); - typedef T d64_t __attribute__((ext_vector_type(64))); - typedef T d128_t __attribute__((ext_vector_type(128))); - typedef T d256_t __attribute__((ext_vector_type(256))); - - using type = d256_t; - - union - { - d256_t d256_; - StaticallyIndexedArray d1x256_; - StaticallyIndexedArray d2x128_; - StaticallyIndexedArray d4x64_; - StaticallyIndexedArray d8x32_; - StaticallyIndexedArray d16x16_; - StaticallyIndexedArray d32x8_; - StaticallyIndexedArray d64x4_; - StaticallyIndexedArray d128x2_; - StaticallyIndexedArray d256x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{0}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert( - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x256_; - } - else if constexpr(is_same::value) - { - return data_.d2x128_; - } - else if constexpr(is_same::value) - { - return data_.d4x64_; - } - else if constexpr(is_same::value) - { - return data_.d8x32_; - } - else if constexpr(is_same::value) - { - return data_.d16x16_; - } - else if constexpr(is_same::value) - { - return data_.d32x8_; - } - else if constexpr(is_same::value) - { - return data_.d64x4_; - } - else if constexpr(is_same::value) - { - return data_.d128x2_; - } - else if constexpr(is_same::value) - { - return data_.d256x1_; - } - else - { - return err; - } - } -}; - -template -struct non_native_vector_base; - -template -struct nnvb_data_t_selector -{ - using type = unsigned _BitInt(8 * sizeof(T)); -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf8_ocp_t::data_type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x16_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = f6x32_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x16_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = bf6x32_pk_t::type; -}; - -template <> -struct nnvb_data_t_selector -{ - using type = pk_i4_t::type; -}; - -template -struct non_native_vector_base< - T, - N, - ck::enable_if_t> -{ - using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - using data_v = data_t __attribute__((ext_vector_type(N))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return data_.dTxN; // XXX this should cause an error - } - } - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same_v || is_same_v || is_same_v, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same_v) - { - return data_.dxN; - } - else if constexpr(is_same_v) - { - return data_.dTxN; - } - else if constexpr(is_same_v) - { - return data_.dNx1; - } - else - { - return err; - } - } -}; - -// implementation for f6x16 and f6x32 -template -struct non_native_vector_base> -{ - using data_t = - typename nnvb_data_t_selector::type; // select data_t based on declared base type - using element_t = typename T::element_type; // select element_t based on declared element type - static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); - static constexpr size_t size_factor = - sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 - using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); - using type = non_native_vector_base; - - union alignas(next_pow2(N * sizeof(T))) - { - data_v dN; // storage vector; - StaticallyIndexedArray dxN; - StaticallyIndexedArray dTxN; - StaticallyIndexedArray dNx1; - } data_; - - __host__ __device__ constexpr non_native_vector_base(data_t a) - : data_{data_v(a.At(Number<0>{}))} - { - } - __host__ __device__ constexpr non_native_vector_base(T f) - : non_native_vector_base(bit_cast(f)) - { - } - __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; - __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} - - __host__ __device__ constexpr operator data_v() const { return data_.dN; } - __host__ __device__ constexpr operator data_t() const - { - if constexpr(N == 1) - { - return data_.dxN[Number<0>{}]; - } - else - { - return data_.dxN; // XXX this should cause an error - } - } - __host__ __device__ constexpr operator T() const - { - if constexpr(N == 1) - { - return data_.dTxN[Number<0>{}]; - } - else - { - return data_.dTxN; // XXX this should cause an error - } - } -}; - -template -struct scalar_type>; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -template -struct scalar_type> -{ - using type = typename non_native_vector_base::data_t; - - static constexpr index_t vector_size = N; -}; - -// non-native vector_type implementation -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using type = d1_nnv_t; - - union alignas(next_pow2(1 * sizeof(T))) - { - d1_t d1_; - StaticallyIndexedArray d1x1_; - d1_nnv_t d1_nnv_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - - using type = d2_t; - - union alignas(next_pow2(2 * sizeof(T))) - { - d2_t d2_; - StaticallyIndexedArray d1x2_; - StaticallyIndexedArray d2x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x2_; - } - else if constexpr(is_same::value) - { - return data_.d2x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - - using type = d4_t; - - union alignas(next_pow2(4 * sizeof(T))) - { - d4_t d4_; - StaticallyIndexedArray d1x4_; - StaticallyIndexedArray d2x2_; - StaticallyIndexedArray d4x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x4_; - } - else if constexpr(is_same::value) - { - return data_.d2x2_; - } - else if constexpr(is_same::value) - { - return data_.d4x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - - using type = d8_t; - - union alignas(next_pow2(8 * sizeof(T))) - { - d8_t d8_; - StaticallyIndexedArray d1x8_; - StaticallyIndexedArray d2x4_; - StaticallyIndexedArray d4x2_; - StaticallyIndexedArray d8x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x8_; - } - else if constexpr(is_same::value) - { - return data_.d2x4_; - } - else if constexpr(is_same::value) - { - return data_.d4x2_; - } - else if constexpr(is_same::value) - { - return data_.d8x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d1_nnv_t = non_native_vector_base; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - - using type = d16_t; - - union alignas(next_pow2(16 * sizeof(T))) - { - d16_t d16_; - StaticallyIndexedArray d1x16_; - StaticallyIndexedArray d2x8_; - StaticallyIndexedArray d4x4_; - StaticallyIndexedArray d8x2_; - StaticallyIndexedArray d16x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value || is_same::value) - { - return data_.d1x16_; - } - else if constexpr(is_same::value) - { - return data_.d2x8_; - } - else if constexpr(is_same::value) - { - return data_.d4x4_; - } - else if constexpr(is_same::value) - { - return data_.d8x2_; - } - else if constexpr(is_same::value) - { - return data_.d16x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - - using type = d32_t; - - union alignas(next_pow2(32 * sizeof(T))) - { - d32_t d32_; - StaticallyIndexedArray d1x32_; - StaticallyIndexedArray d2x16_; - StaticallyIndexedArray d4x8_; - StaticallyIndexedArray d8x4_; - StaticallyIndexedArray d16x2_; - StaticallyIndexedArray d32x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x32_; - } - else if constexpr(is_same::value) - { - return data_.d2x16_; - } - else if constexpr(is_same::value) - { - return data_.d4x8_; - } - else if constexpr(is_same::value) - { - return data_.d8x4_; - } - else if constexpr(is_same::value) - { - return data_.d16x2_; - } - else if constexpr(is_same::value) - { - return data_.d32x1_; - } - else - { - return err; - } - } -}; - -template -struct vector_type()>> -{ - using d1_t = T; - using d2_t = non_native_vector_base; - using d4_t = non_native_vector_base; - using d8_t = non_native_vector_base; - using d16_t = non_native_vector_base; - using d32_t = non_native_vector_base; - using d64_t = non_native_vector_base; - - using type = d64_t; - - union alignas(next_pow2(64 * sizeof(T))) - { - d64_t d64_; - StaticallyIndexedArray d1x64_; - StaticallyIndexedArray d2x32_; - StaticallyIndexedArray d4x16_; - StaticallyIndexedArray d8x8_; - StaticallyIndexedArray d16x4_; - StaticallyIndexedArray d32x2_; - StaticallyIndexedArray d64x1_; - } data_; - - __host__ __device__ constexpr vector_type() : data_{type{}} {} - - __host__ __device__ constexpr vector_type(type v) : data_{v} {} - - template - __host__ __device__ constexpr const auto& AsType() const - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } - - template - __host__ __device__ constexpr auto& AsType() - { - static_assert(is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value, - "Something went wrong, please check src and dst types."); - - if constexpr(is_same::value) - { - return data_.d1x64_; - } - else if constexpr(is_same::value) - { - return data_.d2x32_; - } - else if constexpr(is_same::value) - { - return data_.d4x16_; - } - else if constexpr(is_same::value) - { - return data_.d8x8_; - } - else if constexpr(is_same::value) - { - return data_.d16x4_; - } - else if constexpr(is_same::value) - { - return data_.d32x2_; - } - else if constexpr(is_same::value) - { - return data_.d64x1_; - } - else - { - return err; - } - } -}; - #if defined(_WIN32) using int64_t = long long; #else using int64_t = long; #endif -// fp64 -using double2_t = typename vector_type::type; -using double4_t = typename vector_type::type; - -// fp32 -using float2_t = typename vector_type::type; -using float4_t = typename vector_type::type; -using float8_t = typename vector_type::type; -using float16_t = typename vector_type::type; -using float32_t = typename vector_type::type; -using float64_t = typename vector_type::type; - -// fp16 -using half2_t = typename vector_type::type; -using half4_t = typename vector_type::type; -using half8_t = typename vector_type::type; -using half16_t = typename vector_type::type; -using half32_t = typename vector_type::type; -using half64_t = typename vector_type::type; - -// bfp16 -using bhalf2_t = typename vector_type::type; -using bhalf4_t = typename vector_type::type; -using bhalf8_t = typename vector_type::type; -using bhalf16_t = typename vector_type::type; -using bhalf32_t = typename vector_type::type; -using bhalf64_t = typename vector_type::type; - -// i32 -using int32x2_t = typename vector_type::type; -using int32x4_t = typename vector_type::type; -using int32x8_t = typename vector_type::type; -using int32x16_t = typename vector_type::type; -using int32x32_t = typename vector_type::type; -using int32x64_t = typename vector_type::type; - -// i8 -using int8x2_t = typename vector_type::type; -using int8x4_t = typename vector_type::type; -using int8x8_t = typename vector_type::type; -using int8x16_t = typename vector_type::type; -using int8x32_t = typename vector_type::type; -using int8x64_t = typename vector_type::type; - -// f8 -using f8x2_fnuz_t = typename vector_type::type; -using f8x4_fnuz_t = typename vector_type::type; -using f8x8_fnuz_t = typename vector_type::type; -using f8x16_fnuz_t = typename vector_type::type; -using f8x32_fnuz_t = typename vector_type::type; -using f8x64_fnuz_t = typename vector_type::type; - -// bf8 -using bf8x2_fnuz_t = typename vector_type::type; -using bf8x4_fnuz_t = typename vector_type::type; -using bf8x8_fnuz_t = typename vector_type::type; -using bf8x16_fnuz_t = typename vector_type::type; -using bf8x32_fnuz_t = typename vector_type::type; -using bf8x64_fnuz_t = typename vector_type::type; - -// f8 -using f8x2_ocp_t = typename vector_type::type; -using f8x4_ocp_t = typename vector_type::type; -using f8x8_ocp_t = typename vector_type::type; -using f8x16_ocp_t = typename vector_type::type; -using f8x32_ocp_t = typename vector_type::type; -using f8x64_ocp_t = typename vector_type::type; - -// bf8 -using bf8x2_ocp_t = typename vector_type::type; -using bf8x4_ocp_t = typename vector_type::type; -using bf8x8_ocp_t = typename vector_type::type; -using bf8x16_ocp_t = typename vector_type::type; -using bf8x32_ocp_t = typename vector_type::type; -using bf8x64_ocp_t = typename vector_type::type; - -#if CK_FP8_TYPE_OCP -// f8 -using f8x2_t = f8x2_ocp_t; -using f8x4_t = f8x4_ocp_t; -using f8x8_t = f8x8_ocp_t; -using f8x16_t = f8x16_ocp_t; -using f8x32_t = f8x32_ocp_t; -using f8x64_t = f8x64_ocp_t; - -// bf8 -using bf8x2_t = bf8x2_ocp_t; -using bf8x4_t = bf8x4_ocp_t; -using bf8x8_t = bf8x8_ocp_t; -using bf8x16_t = bf8x16_ocp_t; -using bf8x32_t = bf8x32_ocp_t; -using bf8x64_t = bf8x64_ocp_t; -#elif CK_FP8_TYPE_FNUZ -// f8 -using f8x2_t = f8x2_fnuz_t; -using f8x4_t = f8x4_fnuz_t; -using f8x8_t = f8x8_fnuz_t; -using f8x16_t = f8x16_fnuz_t; -using f8x32_t = f8x32_fnuz_t; -using f8x64_t = f8x64_fnuz_t; - -// bf8 -using bf8x2_t = bf8x2_fnuz_t; -using bf8x4_t = bf8x4_fnuz_t; -using bf8x8_t = bf8x8_fnuz_t; -using bf8x16_t = bf8x16_fnuz_t; -using bf8x32_t = bf8x32_fnuz_t; -using bf8x64_t = bf8x64_fnuz_t; -#endif - -// u8 -using uint8x2_t = typename vector_type::type; -using uint8x4_t = typename vector_type::type; -using uint8x8_t = typename vector_type::type; -using uint8x16_t = typename vector_type::type; -using uint8x32_t = typename vector_type::type; -using uint8x64_t = typename vector_type::type; - -// f4 -using f4x2_t = typename vector_type::type; -using f4x4_t = typename vector_type::type; -using f4x8_t = typename vector_type::type; -using f4x16_t = typename vector_type::type; -using f4x32_t = typename vector_type::type; -using f4x64_t = typename vector_type::type; - -// f6 -using f6x16_t = typename vector_type::type; -using f6x32_t = typename vector_type::type; - -// bf6 -using bf6x16_t = typename vector_type::type; -using bf6x32_t = typename vector_type::type; - -// pack int4 -using pk_i4x2_t = typename vector_type::type; -using pk_i4x4_t = typename vector_type::type; -using pk_i4x8_t = typename vector_type::type; - -#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) -template -struct NumericLimits; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } - - __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } - - __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } -}; -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } - - __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } - - __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } - - __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } - - __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } - - __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } - - __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } - - __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } - - __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } - - __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned int binary_min = 0x00800000; - static constexpr unsigned int binary_max = 0x7F7FFFFF; - static constexpr unsigned int binary_lowest = 0xFF7FFFFF; - static constexpr unsigned int binary_qnan = 0xFFC00001; - static constexpr unsigned int binary_inf = 0x7F8000000; - - __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } - - __host__ __device__ static constexpr float Infinity() { return bit_cast(binary_inf); } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned short binary_min = 0x0400; - static constexpr unsigned short binary_max = 0x7BFF; - static constexpr unsigned short binary_lowest = 0xFBFF; - static constexpr unsigned short binary_qnan = 0x7FFF; - - __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } -}; - -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } - - __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } - - __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } -}; -#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 8 - static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 7 - // static constexpr uint8_t binary_min = 0x08; // 0b00001000 - // static constexpr uint8_t binary_max = 0x77; // 0b01110111 - // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - - __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 16 - static constexpr uint8_t binary_min = 0x04; // 0b00000100 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 15 - // static constexpr uint8_t binary_min = 0x04; // 0b00000100 - // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 - // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - - __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 - static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 - static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 - static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 - - __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr f8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr f8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 - static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 - static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 - static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 - - __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr bf8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr bf8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 - static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 - static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 - static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 - static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 - - static constexpr float data_max_normal_number = 6; - static constexpr float data_min_subnormal_number = 0.5; - - __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } - __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } - __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } - __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } - __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 - - static constexpr float data_max_normal_number = 7.5; - static constexpr float data_min_subnormal_number = 0.125; - - __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Lowest() - { - return f6_t(binary_lowest_normal & 0b111111); - } - __host__ __device__ static constexpr f6_t MinSubnorm() - { - return f6_t(binary_min_subnorm & 0b111111); - } - __host__ __device__ static constexpr f6_t MaxSubnorm() - { - return f6_t(binary_max_subnorm & 0b111111); - } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 - - static constexpr float data_max_normal_number = 28; - static constexpr float data_min_subnormal_number = 0.0625; - - __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } - __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } - __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } - __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } - __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 - static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 - static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 - static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 - static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 - static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 - static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 - static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 - - __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } - __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } - __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_135() - { - return e8m0_bexp_t(binary_135); - } - __host__ __device__ static constexpr e8m0_bexp_t Binary_142() - { - return e8m0_bexp_t(binary_142); - } -}; -#else -template -struct NumericLimits -{ - __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } - __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } - __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } - __host__ __device__ static constexpr T QuietNaN() - { - return std::numeric_limits::quiet_NaN(); - } - __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } -}; - -template <> -struct NumericLimits -{ - static constexpr unsigned short binary_min = 0x0400; - static constexpr unsigned short binary_max = 0x7BFF; - static constexpr unsigned short binary_lowest = 0xFBFF; - static constexpr unsigned short binary_qnan = 0x7FFF; - - __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } - - __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } -}; - -#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 -template <> -struct NumericLimits -{ - __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } - - __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } - - __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } -}; -#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 8 - static constexpr uint8_t binary_min = 0x08; // 0b00001000 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 7 - // static constexpr uint8_t binary_min = 0x08; // 0b00001000 - // static constexpr uint8_t binary_max = 0x77; // 0b01110111 - // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 - - __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - // negative zero nan mode with exp bias = 16 - static constexpr uint8_t binary_min = 0x04; // 0b00000100 - static constexpr uint8_t binary_max = 0x7F; // 0b01111111 - static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 - static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 - // ieee mode with exp bias = 15 - // static constexpr uint8_t binary_min = 0x04; // 0b00000100 - // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 - // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 - // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= - - __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } - - __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } - - __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } - - __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 - static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 - static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 - static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 - - __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr f8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr f8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 - static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 - static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 - static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 - - __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } - - __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } - - __host__ __device__ static constexpr bf8_ocp_t Lowest() - { - return bit_cast(binary_lowest); - } - - __host__ __device__ static constexpr bf8_ocp_t QuietNaN() - { - return bit_cast(binary_qnan); - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 - static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 - static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 - static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 - static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 - - static constexpr float data_max_normal_number = 6; - static constexpr float data_min_subnormal_number = 0.5; - - __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } - __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } - __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } - __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } - __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 - - static constexpr float data_max_normal_number = 7.5; - static constexpr float data_min_subnormal_number = 0.125; - - __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } - __host__ __device__ static constexpr f6_t Lowest() - { - return f6_t(binary_lowest_normal & 0b111111); - } - __host__ __device__ static constexpr f6_t MinSubnorm() - { - return f6_t(binary_min_subnorm & 0b111111); - } - __host__ __device__ static constexpr f6_t MaxSubnorm() - { - return f6_t(binary_max_subnorm & 0b111111); - } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 - static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 - static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 - static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 - static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 - - static constexpr float data_max_normal_number = 28; - static constexpr float data_min_subnormal_number = 0.0625; - - __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } - __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } - __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } - __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } - __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } - - __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } - __host__ __device__ static constexpr float DataMinSubnorm() - { - return data_min_subnormal_number; - } -}; - -template <> -struct NumericLimits -{ - static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 - static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 - static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 - static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 - static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 - static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 - static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 - static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 - - __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } - __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } - __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } - __host__ __device__ static constexpr e8m0_bexp_t Binary_135() - { - return e8m0_bexp_t(binary_135); - } - __host__ __device__ static constexpr e8m0_bexp_t Binary_142() - { - return e8m0_bexp_t(binary_142); - } -}; -#endif - -template -struct NumericUtils -{ -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 23; - static constexpr int bias = 127; - static constexpr uint32_t nan_mask = 0x7F800000; - static constexpr uint32_t head_mask = 0xFF800000; - static constexpr uint32_t mant_mask = 0x7FFFFF; - static constexpr uint32_t exp_mask = 0xFF; - static constexpr uint32_t Inf = 0x7F800000; - static constexpr uint32_t NegInf = 0xFF800000; - static constexpr uint32_t NaN = 0x7F800001; - static constexpr uint32_t Neg0 = 0x80000000; - static constexpr bool has_inf = true; - using bitwise_type = uint32_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 5; - static constexpr int mant = 10; - static constexpr int bias = 15; - static constexpr uint16_t nan_mask = 0x7C00; - static constexpr uint16_t head_mask = 0xFC00; - static constexpr uint16_t mant_mask = 0x3FF; - static constexpr uint16_t exp_mask = 0x1F; - static constexpr uint32_t Inf = 0x7C00; - static constexpr uint32_t NegInf = 0xFC00; - static constexpr uint32_t NaN = 0x7C01; - static constexpr uint32_t Neg0 = 0x8000; - static constexpr bool has_inf = true; - using bitwise_type = uint16_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 7; - static constexpr int bias = 128; // negative zero nan mode - // static constexpr int bias = 127; // ieee mode -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 4; - static constexpr int mant = 3; - static constexpr int bias = 8; // negative zero nan mode - // static constexpr int bias = 7; // ieee mode - static constexpr bool has_inf = false; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 5; - static constexpr int mant = 2; - static constexpr int bias = 16; // negative zero nan mode - // static constexpr int bias = 15; // ieee mode - static constexpr bool has_inf = false; -}; -template <> -struct NumericUtils -{ - static constexpr int exp = 4; - static constexpr int mant = 3; - static constexpr int bias = 7; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 5; - static constexpr int mant = 2; - static constexpr int bias = 15; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 2; - static constexpr int mant = 1; - static constexpr int bias = 1; - static constexpr uint32_t sr_shift = 10; - - static constexpr int unbiased_exp_min = 0; - static constexpr int unbiased_exp_max = 2; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 3; - - static constexpr uint8_t positive_zero_mask = 0b0000; - static constexpr uint8_t negative_zero_mask = 0b1000; - - static constexpr uint8_t one_mask = 0b0010; - static constexpr uint8_t set_sign_mask = 0b0111; - - static constexpr uint8_t data_max_positive_normal_mask = 0b0111; - static constexpr uint8_t data_max_negative_normal_mask = 0b1111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; - - static constexpr bool has_inf = false; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 2; - static constexpr int mant = 3; - static constexpr int bias = 1; - static constexpr uint32_t sr_shift = 12; - - static constexpr int unbiased_exp_min = 0; - static constexpr int unbiased_exp_max = 2; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 3; - - static constexpr uint8_t positive_zero_mask = 0b000000; - static constexpr uint8_t negative_zero_mask = 0b100000; - - static constexpr uint8_t set_sign_mask = 0b011111; - - static constexpr uint8_t data_max_positive_normal_mask = 0b011111; - static constexpr uint8_t data_max_negative_normal_mask = 0b111111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; - - static constexpr bool has_inf = false; - static constexpr bool has_nan = false; - static constexpr bool has_zero = true; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 3; - static constexpr int mant = 2; - static constexpr int bias = 3; - static constexpr uint32_t sr_shift = 11; - - static constexpr int unbiased_exp_min = -2; - static constexpr int unbiased_exp_max = 4; - static constexpr int biased_exp_min = 1; - static constexpr int biased_exp_max = 7; - - static constexpr uint8_t positive_zero_mask = 0b000000; - static constexpr uint8_t negative_zero_mask = 0b100000; - - static constexpr uint8_t set_sign_mask = 0b011111; - - static constexpr uint8_t data_max_positive_normal_mask = 0b011111; - static constexpr uint8_t data_max_negative_normal_mask = 0b111111; - - static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; - static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; - - static constexpr bool has_inf = false; - static constexpr bool has_nan = false; - static constexpr bool has_zero = true; - - using bitwise_type = uint8_t; -}; - -template <> -struct NumericUtils -{ - static constexpr int exp = 8; - static constexpr int mant = 0; - static constexpr int bias = 127; - - static constexpr int unbiased_exp_min = -127; - static constexpr int unbiased_exp_max = 127; - static constexpr int biased_exp_min = 0; - static constexpr int biased_exp_max = 254; - - using bitwise_type = uint8_t; -}; } // namespace ck diff --git a/include/ck/utility/dtype_fp64.hpp b/include/ck/utility/dtype_fp64.hpp new file mode 100644 index 0000000000..3c63d083ad --- /dev/null +++ b/include/ck/utility/dtype_fp64.hpp @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +namespace ck { +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; +} // namespace ck diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp new file mode 100644 index 0000000000..302ebd86b7 --- /dev/null +++ b/include/ck/utility/dtype_vector.hpp @@ -0,0 +1,2152 @@ +// SPDX-License-Identifier: MIT +// // // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +template +struct vector_type()>> +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } +}; + +__device__ int static err = 0; +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + + using type = d3_t; + + union + { + d3_t d3_; + StaticallyIndexedArray d1x3_; + StaticallyIndexedArray d2x1_; + StaticallyIndexedArray d3x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d5_t __attribute__((ext_vector_type(5))); + + using type = d5_t; + + union + { + d5_t d5_; + StaticallyIndexedArray d1x5_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d5x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d7_t __attribute__((ext_vector_type(7))); + + using type = d7_t; + + union + { + d7_t d7_; + StaticallyIndexedArray d1x7_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d7x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d13_t __attribute__((ext_vector_type(13))); + + using type = d13_t; + + union + { + d13_t d13_; + StaticallyIndexedArray d1x13_; + StaticallyIndexedArray d4x3_; + StaticallyIndexedArray d8x1_; + StaticallyIndexedArray d13x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_ = {d32_t{0}}; + + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } + + // __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + // __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } +}; + +template +struct non_native_vector_base; + +template +struct nnvb_data_t_selector +{ + using type = unsigned _BitInt(8 * sizeof(T)); +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> +{ + using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = + sizeof(data_t) / sizeof(element_t); // f6x16: 12/4 = 3, f6x32: 24/4 = 6 + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) + : data_{data_v(a.At(Number<0>{}))} + { + } + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } +}; + +template +struct scalar_type>; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type> +{ + using type = typename non_native_vector_base::data_t; + + static constexpr index_t vector_size = N; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using type = d1_nnv_t; + + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + d1_nnv_t d1_nnv_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + + using type = d2_t; + + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + + using type = d4_t; + + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +using int64_t = long; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; + +// bf8 +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + +// pack int4 +using pk_i4x2_t = typename vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + +} // namespace ck diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 2533073225..799683ae65 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_utils.hpp" namespace ck { diff --git a/include/ck/utility/generic_memory_space_atomic.hpp b/include/ck/utility/generic_memory_space_atomic.hpp index 98f40a4363..ab9cc4199c 100644 --- a/include/ck/utility/generic_memory_space_atomic.hpp +++ b/include/ck/utility/generic_memory_space_atomic.hpp @@ -3,6 +3,7 @@ #pragma once #include "data_type.hpp" +#include "dtype_fp64.hpp" namespace ck { diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 05ae9093e2..7b079c541c 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck/ck.hpp" -#include "data_type.hpp" +#include "numeric_limits.hpp" #include "integral_constant.hpp" #include "number.hpp" #include "type.hpp" diff --git a/include/ck/utility/mxf4_utils.hpp b/include/ck/utility/mxf4_utils.hpp index 757d3914e3..72a0bb919c 100644 --- a/include/ck/utility/mxf4_utils.hpp +++ b/include/ck/utility/mxf4_utils.hpp @@ -4,7 +4,7 @@ #ifndef CK_CODE_GEN_RTC #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck::utils { diff --git a/include/ck/utility/mxf6_utils.hpp b/include/ck/utility/mxf6_utils.hpp index 00b4f8e5d4..cf68188b3e 100644 --- a/include/ck/utility/mxf6_utils.hpp +++ b/include/ck/utility/mxf6_utils.hpp @@ -4,7 +4,7 @@ #ifndef CK_CODE_GEN_RTC #pragma once -#include "ck/utility/data_type.hpp" +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" namespace ck::utils { diff --git a/include/ck/utility/mxf8_utils.hpp b/include/ck/utility/mxf8_utils.hpp index 2dbf997f6a..b7b98c6455 100644 --- a/include/ck/utility/mxf8_utils.hpp +++ b/include/ck/utility/mxf8_utils.hpp @@ -1,4 +1,7 @@ -#include "ck/utility/data_type.hpp" +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/utility/numeric_limits.hpp" #include "ck/utility/mxfp_utils.hpp" #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ diff --git a/include/ck/utility/numeric_limits.hpp b/include/ck/utility/numeric_limits.hpp new file mode 100644 index 0000000000..e59b7eceaf --- /dev/null +++ b/include/ck/utility/numeric_limits.hpp @@ -0,0 +1,555 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +template +struct NumericLimits; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; } + + __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; } + + __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int32_t QuietNaN() { return 0; } +}; +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Min() noexcept { return -32768; } + + __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; } + + __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Min() noexcept { return -128; } + + __host__ __device__ static constexpr int8_t Max() noexcept { return 127; } + + __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr int8_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; } + + __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Min() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; } + + __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; } + + __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned int binary_min = 0x00800000; + static constexpr unsigned int binary_max = 0x7F7FFFFF; + static constexpr unsigned int binary_lowest = 0xFF7FFFFF; + static constexpr unsigned int binary_qnan = 0xFFC00001; + static constexpr unsigned int binary_inf = 0x7F800000; + + __host__ __device__ static constexpr float Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr float Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr float Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr float QuietNaN() { return bit_cast(binary_qnan); } + + __host__ __device__ static constexpr float Infinity() { return bit_cast(binary_inf); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +#else +template +struct NumericLimits +{ + __host__ __device__ static constexpr T Min() { return std::numeric_limits::min(); } + __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } + __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } + __host__ __device__ static constexpr T QuietNaN() + { + return std::numeric_limits::quiet_NaN(); + } + __host__ __device__ static constexpr T Infinity() { return std::numeric_limits::infinity(); } +}; + +template <> +struct NumericLimits +{ + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; + static constexpr unsigned short binary_qnan = 0x7FFF; + + __host__ __device__ static constexpr half_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return bit_cast(binary_lowest); } + + __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast(binary_qnan); } +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct NumericLimits +{ + __host__ __device__ static constexpr int4_t Min() { return int4_t(-8); } + + __host__ __device__ static constexpr int4_t Max() { return int4_t(7); } + + __host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); } +}; +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 8 + static constexpr uint8_t binary_min = 0x08; // 0b00001000 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 7 + // static constexpr uint8_t binary_min = 0x08; // 0b00001000 + // static constexpr uint8_t binary_max = 0x77; // 0b01110111 + // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0 + + __host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + // negative zero nan mode with exp bias = 16 + static constexpr uint8_t binary_min = 0x04; // 0b00000100 + static constexpr uint8_t binary_max = 0x7F; // 0b01111111 + static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111 + static constexpr uint8_t binary_qnan = 0x80; // 0b10000000 + // ieee mode with exp bias = 15 + // static constexpr uint8_t binary_min = 0x04; // 0b00000100 + // static constexpr uint8_t binary_max = 0x7B; // 0b01111011 + // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 + // static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!= + + __host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); } + + __host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); } + + __host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); } + + __host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6 + static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448 + static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448 + static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111 + + __host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr f8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr f8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14 + static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344 + static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344 + static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101 + + __host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast(binary_min); } + + __host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast(binary_max); } + + __host__ __device__ static constexpr bf8_ocp_t Lowest() + { + return bit_cast(binary_lowest); + } + + __host__ __device__ static constexpr bf8_ocp_t QuietNaN() + { + return bit_cast(binary_qnan); + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x2; // 0b0010 + static constexpr uint8_t binary_max_normal = 0x7; // 0b0111 + static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111 + static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001 + static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001 + + static constexpr float data_max_normal_number = 6; + static constexpr float data_min_subnormal_number = 0.5; + + __host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); } + __host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); } + __host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); } + __host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); } + __host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111 + + static constexpr float data_max_normal_number = 7.5; + static constexpr float data_min_subnormal_number = 0.125; + + __host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); } + __host__ __device__ static constexpr f6_t Lowest() + { + return f6_t(binary_lowest_normal & 0b111111); + } + __host__ __device__ static constexpr f6_t MinSubnorm() + { + return f6_t(binary_min_subnorm & 0b111111); + } + __host__ __device__ static constexpr f6_t MaxSubnorm() + { + return f6_t(binary_max_subnorm & 0b111111); + } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +template <> +struct NumericLimits +{ + static constexpr uint8_t binary_min_normal = 0x08; // 0b001000 + static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111 + static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111 + static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001 + static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011 + + static constexpr float data_max_normal_number = 28; + static constexpr float data_min_subnormal_number = 0.0625; + + __host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); } + __host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); } + __host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); } + __host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); } + __host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); } + + __host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; } + __host__ __device__ static constexpr float DataMinSubnorm() + { + return data_min_subnormal_number; + } +}; + +#endif + +template <> +struct NumericLimits +{ + static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000 + static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110 + static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111 + static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111 + static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000 + static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010 + static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111 + static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110 + + __host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); } + __host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); } + __host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); } + __host__ __device__ static constexpr e8m0_bexp_t Binary_135() + { + return e8m0_bexp_t(binary_135); + } + __host__ __device__ static constexpr e8m0_bexp_t Binary_142() + { + return e8m0_bexp_t(binary_142); + } +}; + +} // namespace ck diff --git a/include/ck/utility/numeric_utils.hpp b/include/ck/utility/numeric_utils.hpp new file mode 100644 index 0000000000..726f667518 --- /dev/null +++ b/include/ck/utility/numeric_utils.hpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +template +struct NumericUtils +{ +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 0; + static constexpr int bias = 127; + + static constexpr int unbiased_exp_min = -127; + static constexpr int unbiased_exp_max = 127; + static constexpr int biased_exp_min = 0; + static constexpr int biased_exp_max = 254; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + static constexpr bool has_inf = true; + using bitwise_type = uint32_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 10; + static constexpr int bias = 15; + static constexpr uint16_t nan_mask = 0x7C00; + static constexpr uint16_t head_mask = 0xFC00; + static constexpr uint16_t mant_mask = 0x3FF; + static constexpr uint16_t exp_mask = 0x1F; + static constexpr uint32_t Inf = 0x7C00; + static constexpr uint32_t NegInf = 0xFC00; + static constexpr uint32_t NaN = 0x7C01; + static constexpr uint32_t Neg0 = 0x8000; + static constexpr bool has_inf = true; + using bitwise_type = uint16_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 8; + static constexpr int mant = 7; + static constexpr int bias = 128; // negative zero nan mode + // static constexpr int bias = 127; // ieee mode +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 8; // negative zero nan mode + // static constexpr int bias = 7; // ieee mode + static constexpr bool has_inf = false; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 16; // negative zero nan mode + // static constexpr int bias = 15; // ieee mode + static constexpr bool has_inf = false; +}; +template <> +struct NumericUtils +{ + static constexpr int exp = 4; + static constexpr int mant = 3; + static constexpr int bias = 7; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 5; + static constexpr int mant = 2; + static constexpr int bias = 15; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 10; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b0000; + static constexpr uint8_t negative_zero_mask = 0b1000; + + static constexpr uint8_t one_mask = 0b0010; + static constexpr uint8_t set_sign_mask = 0b0111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b0111; + static constexpr uint8_t data_max_negative_normal_mask = 0b1111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001; + + static constexpr bool has_inf = false; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 2; + static constexpr int mant = 3; + static constexpr int bias = 1; + static constexpr uint32_t sr_shift = 12; + + static constexpr int unbiased_exp_min = 0; + static constexpr int unbiased_exp_max = 2; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 3; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; + +template <> +struct NumericUtils +{ + static constexpr int exp = 3; + static constexpr int mant = 2; + static constexpr int bias = 3; + static constexpr uint32_t sr_shift = 11; + + static constexpr int unbiased_exp_min = -2; + static constexpr int unbiased_exp_max = 4; + static constexpr int biased_exp_min = 1; + static constexpr int biased_exp_max = 7; + + static constexpr uint8_t positive_zero_mask = 0b000000; + static constexpr uint8_t negative_zero_mask = 0b100000; + + static constexpr uint8_t set_sign_mask = 0b011111; + + static constexpr uint8_t data_max_positive_normal_mask = 0b011111; + static constexpr uint8_t data_max_negative_normal_mask = 0b111111; + + static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011; + static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011; + + static constexpr bool has_inf = false; + static constexpr bool has_nan = false; + static constexpr bool has_zero = true; + + using bitwise_type = uint8_t; +}; +} // namespace ck