From f55c7629bc6e6aab79a9700dd323743729e2d6e3 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 Mar 2024 23:23:32 +0000 Subject: [PATCH] not using custom data type by default, now we can have ISA-level same code as opt_padding --- example/ck_tile/01_fmha/generate.py | 2 +- include/ck_tile/core.hpp | 3 +- .../core/algorithm/coordinate_transform.hpp | 16 +- .../core/arch/amd_buffer_addressing.hpp | 4 +- include/ck_tile/core/config.hpp | 4 + .../ck_tile/core/container/thread_buffer.hpp | 68 +++++ include/ck_tile/core/numeric/bfloat16.hpp | 115 +++++--- include/ck_tile/core/numeric/float8.hpp | 262 ++++++++++++------ include/ck_tile/core/numeric/half.hpp | 125 ++++++--- .../numeric/{arithmetic.hpp => numeric.hpp} | 96 ++++++- include/ck_tile/core/numeric/type_convert.hpp | 24 +- include/ck_tile/core/numeric/vector_type.hpp | 25 +- include/ck_tile/core/tensor/buffer_view.hpp | 19 +- .../ck_tile/core/tensor/tensor_adaptor.hpp | 6 +- .../ck_tile/core/tensor/tile_elementwise.hpp | 5 +- include/ck_tile/core/tensor/tile_window.hpp | 9 +- include/ck_tile/core/utility/limits.hpp | 75 ----- .../reference/reference_batched_masking.hpp | 2 +- .../reference/reference_batched_softmax.hpp | 2 +- .../host/reference/reference_softmax.hpp | 2 +- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 2 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 23 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 23 +- .../block_fmha_pipeline_qr_ks_vs_fp8.hpp | 21 +- .../pipeline/block_fmha_pipeline_qs_ks_vs.hpp | 23 +- include/ck_tile/ops/gemm.hpp | 1 - 26 files changed, 629 insertions(+), 328 deletions(-) rename include/ck_tile/core/numeric/{arithmetic.hpp => numeric.hpp} (69%) delete mode 100644 include/ck_tile/core/utility/limits.hpp diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index f2b7a61c17..5c44ad303b 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -11,7 +11,7 @@ import copy import fnmatch DTYPE_MAP = { - 
"fp16": "ck_tile::half_t", + "fp16": "ck_tile::fp16_t", "bf16": "ck_tile::bf16_t", "fp8" : "ck_tile::fp8_t" } diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index daf5a12d2d..9ac55c1197 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -20,13 +20,13 @@ #include "ck_tile/core/container/statically_indexed_array.hpp" #include "ck_tile/core/container/thread_buffer.hpp" #include "ck_tile/core/container/tuple.hpp" -#include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" #include "ck_tile/core/tensor/buffer_view.hpp" @@ -49,7 +49,6 @@ #include "ck_tile/core/tensor/tile_window.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" -#include "ck_tile/core/utility/limits.hpp" #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index ad7054aabd..71602e5d13 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -47,8 +47,8 @@ struct base_transform { if constexpr(NDimUp > 0) { - array up_vector_lengths = make_array_with({-1}); - array up_vector_strides = make_array_with({-1}); + array up_vector_lengths{-1}; + array up_vector_strides{-1}; return make_tuple(up_vector_lengths, up_vector_strides); } @@ -690,8 +690,8 @@ struct merge_v2_magic_division : public base_transform 
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides) { - array up_vector_lengths = make_array_with({-1}); - array up_vector_strides = make_array_with({-1}); + array up_vector_lengths{-1}; + array up_vector_strides{-1}; up_vector_lengths[0] = low_vector_lengths[number{}]; up_vector_strides[0] = low_vector_strides[number{}]; @@ -821,8 +821,8 @@ struct merge_v3_division_mod : public base_transform calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides) { - array up_vector_lengths = make_array_with({-1}); - array up_vector_strides = make_array_with({-1}); + array up_vector_lengths{-1}; + array up_vector_strides{-1}; up_vector_lengths[0] = low_vector_lengths[number{}]; up_vector_strides[0] = low_vector_strides[number{}]; @@ -940,8 +940,8 @@ struct unmerge : public base_transform<1, UpLengths::size()> calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths, const LowVectorStrides& low_vector_strides) { - array up_vector_lengths = make_array_with({-1}); - array up_vector_strides = make_array_with({-1}); + array up_vector_lengths{-1}; + array up_vector_strides{-1}; constexpr auto up_length_last = UpLengths{}[number{}]; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index c37af77ad4..53f42a7421 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -414,7 +414,7 @@ struct buffer_store_if<8> static_assert(sizeof(T) == 8); auto save_exec = __builtin_amdgcn_read_exec(); // TODO: ugly. 
rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch - using mbuf_t = ext_vector_t; + using mbuf_t = ext_vector_t; asm volatile("v_cmpx_le_u32 exec, 1, %5\n" "buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n" "s_mov_b64 exec %6" @@ -1778,7 +1778,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, thread_buffer tmp = amd_buffer_load_impl(src_wave_buffer_resource, src_thread_addr_offset, 0); if constexpr(oob_conditional_check) - return src_thread_element_valid ? tmp : thread_buffer{0}; + return src_thread_element_valid ? tmp : thread_buffer{numeric::zero()}; else return tmp; #endif diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 4688356ff1..d915df6e4c 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -20,6 +20,10 @@ #define CK_TILE_DEVICE_EXTERN #endif +#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE +#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code +#endif + #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2 diff --git a/include/ck_tile/core/container/thread_buffer.hpp b/include/ck_tile/core/container/thread_buffer.hpp index 7b8895a953..3c3894c148 100644 --- a/include/ck_tile/core/container/thread_buffer.hpp +++ b/include/ck_tile/core/container/thread_buffer.hpp @@ -19,6 +19,8 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts) return make_tuple(ts...); } #else + +#if 0 template using thread_buffer = array; @@ -27,6 +29,72 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... 
ts) { return make_array(ts...); } + +#endif + +// clang-format off +template +struct thread_buffer { + using value_type = remove_cvref_t; + static constexpr index_t N = N_; + + value_type data[N]; + + // TODO: this ctor can't ignore + CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {} + CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {} + + CK_TILE_HOST_DEVICE static constexpr auto size() { return N; } + CK_TILE_HOST_DEVICE auto & get() {return data; } + CK_TILE_HOST_DEVICE const auto & get() const {return data; } + CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; } + CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; } + CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); } + CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible + CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); } + CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); } + template CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); } + template CK_TILE_HOST_DEVICE constexpr auto& at(number) { return get(I); } + template CK_TILE_HOST_DEVICE constexpr const auto& at(number) const { return get(I); } + +#define TB_COMMON_AS() \ + static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \ + constexpr int vx = sizeof(value_type) * N / sizeof(Tx) + + template + CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS(); + return reinterpret_cast&>(data);} + template + CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS(); + return reinterpret_cast&>(data);} + template + CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS(); + return reinterpret_cast&>(data).get(i);} + template + CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const 
{TB_COMMON_AS(); + return reinterpret_cast&>(data).get(i);} + + template CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x) + { TB_COMMON_AS(); reinterpret_cast&>(data).at(i) = x; } + template CK_TILE_HOST_DEVICE constexpr void set_as(number, const Tx & x) + { TB_COMMON_AS(); reinterpret_cast&>(data).at(number{}) = x; } +#undef TB_COMMON_AS +}; +// clang-format on + +template +struct vector_traits; + +// specialization for array +template +struct vector_traits> +{ + using scalar_type = T; + static constexpr index_t vector_size = N; +}; + #endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index abcc8fdc1a..8ac9545633 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -3,10 +3,9 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" -#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include #pragma once @@ -22,18 +21,19 @@ enum class bf16_rounding_mode template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> -CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant = {}); +CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant = {}); template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> -CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant = {}); +CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant = {}); CK_TILE_HOST_DEVICE -float bf16_to_float_raw(uint16_t x); +constexpr float bf16_to_float_raw(uint16_t x); CK_TILE_HOST_DEVICE -double bf16_to_double_raw(uint16_t x); +constexpr double bf16_to_double_raw(uint16_t x); +#if CK_TILE_USE_CUSTOM_DATA_TYPE // HIP use __hip_bfloat16 as struct struct alignas(2) bfloat16_t { @@ -89,13 +89,24 @@ struct alignas(2) bfloat16_t 
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } }; +template +struct native_t; +template <> +struct native_t +{ + using type = ushort; +}; using bf16_t = bfloat16_t; using bf16_raw_t = typename bf16_t::raw_type; - +#else +using bfloat16_t = ushort; +using bf16_t = bfloat16_t; +using bf16_raw_t = uint16_t; +#endif // round to nearest CK_TILE_HOST_DEVICE -uint16_t float_to_bf16_rtn_raw(float f) +constexpr uint16_t float_to_bf16_rtn_raw(float f) { union { @@ -139,7 +150,7 @@ uint16_t float_to_bf16_rtn_raw(float f) // Truncate instead of rounding, preserving SNaN CK_TILE_HOST_DEVICE -uint16_t float_to_bf16_truc_nan_raw(float f) +constexpr uint16_t float_to_bf16_truc_nan_raw(float f) { union { @@ -151,7 +162,7 @@ uint16_t float_to_bf16_truc_nan_raw(float f) // Fast truncate instead of rounding, RTZ CK_TILE_HOST_DEVICE -uint16_t float_to_bf16_truc_raw(float f) +constexpr uint16_t float_to_bf16_truc_raw(float f) { union { @@ -162,7 +173,7 @@ uint16_t float_to_bf16_truc_raw(float f) } template -CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant) +CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant) { if constexpr(rounding == bf16_rounding_mode::standard) return float_to_bf16_rtn_raw(f); @@ -173,13 +184,13 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant) } template -CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant) +CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant) { return float_to_bf16_raw(static_cast(f), constant{}); } CK_TILE_HOST_DEVICE -float bf16_to_float_raw(uint16_t x) +constexpr float bf16_to_float_raw(uint16_t x) { union { @@ -190,100 +201,118 @@ float bf16_to_float_raw(uint16_t x) } CK_TILE_HOST_DEVICE -double bf16_to_double_raw(uint16_t x) { return static_cast(bf16_to_float_raw(x)); } - -template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> -CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant) +constexpr double 
bf16_to_double_raw(uint16_t x) { - return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant{})); + return static_cast(bf16_to_float_raw(x)); } template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> -CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant) +CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { - return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant{})); + return bit_cast(float_to_bf16_raw(f, constant{})); } -CK_TILE_HOST_DEVICE -float bf16_to_float(bfloat16_t x) { return static_cast(x); } - -CK_TILE_HOST_DEVICE -double bf16_to_double(bfloat16_t x) { return static_cast(x); } - template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> -CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant = {}) +CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant = {}) { - return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast(f), constant{})); + return bit_cast(double_to_bf16_raw(f, constant{})); } CK_TILE_HOST_DEVICE -half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast(x)); } +constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast(x)); } + +CK_TILE_HOST_DEVICE +constexpr double bf16_to_double(bfloat16_t x) { return static_cast(bf16_to_float_raw(x)); } + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant = {}) +{ + return bit_cast(float_to_bf16_raw(static_cast(f), constant{})); +} + +CK_TILE_HOST_DEVICE +constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast(static_cast(x)); } template -struct numeric_limits; +struct numeric; template <> -struct numeric_limits +struct numeric { // minimum finite value, or minimum positive normalized value for float - CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); } + CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() + { + return bit_cast(static_cast(0x0080)); + } // minumum finite value CK_TILE_HOST_DEVICE 
static constexpr bfloat16_t lowest() { - return bfloat16_t::bit_cast(0xff7f); + return bit_cast(static_cast(0xff7f)); } // maximum finite value - CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); } + CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() + { + return bit_cast(static_cast(0x7f7f)); + } // difference between 1.0 and next value representable by float CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon() { - return bfloat16_t::bit_cast(0x1000); + return bit_cast(static_cast(0x1000)); } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); } + CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); } // positive infinity value CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity() { - return bfloat16_t::bit_cast(0x7f80); + return bit_cast(static_cast(0x7f80)); } // quiet NaN CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN() { - return bfloat16_t::bit_cast(0x7FFF); + return bit_cast(static_cast(0x7FFF)); } // signaling NaN CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN() { - return bfloat16_t::bit_cast(0x7FFF); + return bit_cast(static_cast(0x7FFF)); } // smallest positive subnormal value CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min() { - return bfloat16_t::bit_cast(0x0001); + return bit_cast(static_cast(0x0001)); + } + CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero() + { + return bit_cast(static_cast(0)); } }; +#if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t) +#endif // math CK_TILE_HOST_DEVICE -bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); } +bfloat16_t abs(const bfloat16_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7fff)); +} CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t& x) { - uint16_t xx = x.get(); + uint16_t xx = bit_cast(x); return (xx & 0x7FFF) > 0x7C00; } 
diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index e94c9e7764..05f65309d9 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -3,13 +3,12 @@ #include "ck_tile/core/config.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include "ck_tile/core/utility/random.hpp" -#include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" -#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include #include @@ -43,14 +42,15 @@ enum class fp8_rounding_mode */ template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant = {}); +CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant = {}); template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant = {}); +CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant = {}); -CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t); -CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t); +#if CK_TILE_USE_CUSTOM_DATA_TYPE struct alignas(1) float8_e4m3_t { static constexpr int exponent = 4; @@ -58,7 +58,7 @@ struct alignas(1) float8_e4m3_t #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) static constexpr int bias = 1 << (exponent - 1); // NANOO #else - static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE + static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE #endif using raw_type = uint8_t; raw_type data; @@ -116,7 +116,7 @@ struct alignas(1) float8_e5m2_t #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) 
static constexpr int bias = 1 << (exponent - 1); // NANOO #else - static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE + static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE #endif using raw_type = uint8_t; raw_type data; @@ -167,6 +167,28 @@ struct alignas(1) float8_e5m2_t using bf8_t = float8_e5m2_t; using bf8_raw_t = typename bf8_t::raw_type; +template +struct native_t; + +template <> +struct native_t +{ + using type = _BitInt(8); +}; + +template <> +struct native_t +{ + using type = unsigned _BitInt(8); +}; + +#else +using fp8_t = _BitInt(8); +using fp8_raw_t = uint8_t; +using bf8_t = unsigned _BitInt(8); +using bf8_raw_t = uint8_t; +#endif + // below is sw fp8 conversion, not utilizing hw instruction namespace impl { @@ -174,29 +196,35 @@ template CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng) { // fp8/bf8 exponent/mantissa layout - constexpr int out_exp = numeric_utils::exp; - constexpr int out_mant = numeric_utils::mant; + constexpr int out_exp = numeric_traits::exp; + constexpr int out_mant = numeric_traits::mant; // original type exponent/mantissa layout - constexpr int in_exp = numeric_utils::exp; - constexpr int in_mant = numeric_utils::mant; + constexpr int in_exp = numeric_traits::exp; + constexpr int in_mant = numeric_traits::mant; int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr Y nan_code = __builtin_bit_cast(Y, static_cast(0x80)); - constexpr uint32_t nan_mask = numeric_utils::nan_mask; +#if CK_TILE_USE_CUSTOM_DATA_TYPE + constexpr Y nan_code = + numeric::quiet_NaN(); // __builtin_bit_cast(Y, static_cast(0x80)); +#else + constexpr Y nan_code = 0x80; +#endif + + constexpr uint32_t nan_mask = numeric_traits::nan_mask; // convert to bitwise - using T_bitwise = typename numeric_utils::bitwise_type; + using T_bitwise = typename numeric_traits::bitwise_type; T_bitwise x_bitwise = *(reinterpret_cast(&x)); // unpack the input, depends on datatype - head = x_bitwise & 
numeric_utils::head_mask; - mantissa = x_bitwise & numeric_utils::mant_mask; - exponent = (head >> in_mant) & numeric_utils::exp_mask; + head = x_bitwise & numeric_traits::head_mask; + mantissa = x_bitwise & numeric_traits::mant_mask; + exponent = (head >> in_mant) & numeric_traits::exp_mask; sign = head >> (in_exp + in_mant); - bias = numeric_utils::bias; + bias = numeric_traits::bias; uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; @@ -335,23 +363,23 @@ template CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) { // fp8/bf8 exponent/mantissa layout - constexpr int in_exp = numeric_utils::exp; - constexpr int in_mant = numeric_utils::mant; + constexpr int in_exp = numeric_traits::exp; + constexpr int in_mant = numeric_traits::mant; // resulting type exponent/mantissa layout - constexpr int out_exp = numeric_utils::exp; - constexpr int out_mant = numeric_utils::mant; + constexpr int out_exp = numeric_traits::exp; + constexpr int out_mant = numeric_traits::mant; uint8_t x_raw = __builtin_bit_cast(uint8_t, x); // prepare the codes constexpr uint8_t nan_code = 0x80; Y Inf, NegInf, NaN, Neg0; - using T_bitwise = typename numeric_utils::bitwise_type; + using T_bitwise = typename numeric_traits::bitwise_type; - constexpr T_bitwise Inf_bitwise = numeric_utils::Inf; - constexpr T_bitwise NegInf_bitwise = numeric_utils::NegInf; - constexpr T_bitwise NaN_bitwise = numeric_utils::NaN; - constexpr T_bitwise Neg0_bitwise = numeric_utils::Neg0; + constexpr T_bitwise Inf_bitwise = numeric_traits::Inf; + constexpr T_bitwise NegInf_bitwise = numeric_traits::NegInf; + constexpr T_bitwise NaN_bitwise = numeric_traits::NaN; + constexpr T_bitwise Neg0_bitwise = numeric_traits::Neg0; Inf = *(reinterpret_cast(&Inf_bitwise)); NegInf = *(reinterpret_cast(&NegInf_bitwise)); @@ -384,7 +412,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x) return (mantissa == 0) ? (sign ? 
NegInf : Inf) : NaN; } - if((numeric_utils::mant == 10) && (numeric_utils::mant == 2) && !negative_zero_nan) + if((numeric_traits::mant == 10) && (numeric_traits::mant == 2) && !negative_zero_nan) { retval = x_raw; retval <<= 8; @@ -553,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) // clang-format off template -CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) +CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); @@ -561,14 +589,14 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) } template -CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) +CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); else return bf8_raw_t{0}; } -CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) +CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -578,11 +606,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) return fval; #else constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(fp8_t::bit_cast(x)); + return impl::cast_from_f8(bit_cast(x)); #endif } -CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) +CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -592,129 +620,200 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) return fval; #else constexpr bool negative_zero_nan = true; - return impl::cast_from_f8(bf8_t::bit_cast(x)); + return 
impl::cast_from_f8(bit_cast(x)); #endif } -template -CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant) +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant = {}) { - return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant{})); + return bit_cast(float_to_fp8_raw(x, constant{})); } -template -CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant) +template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant = {}) { - return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant{})); + return bit_cast(float_to_bf8_raw(x, constant{})); } -CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x) +CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x) { - return fp8_to_float_raw(x.get()); + return fp8_to_float_raw(bit_cast(x)); } -CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x) +CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x) { - return bf8_to_float_raw(x.get()); + return bf8_to_float_raw(bit_cast(x)); } // clang-format on template -struct numeric_utils; +struct numeric_traits; template <> -struct numeric_utils +struct numeric_traits { - static constexpr int exp = fp8_t::exponent; - static constexpr int mant = fp8_t::mantissa; - static constexpr int bias = fp8_t::bias; + static constexpr int exp = 4; + static constexpr int mant = 3; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 8; +#else + static constexpr int bias = 7; +#endif }; template <> -struct numeric_utils +struct numeric_traits { - static constexpr int exp = bf8_t::exponent; - static constexpr int mant = bf8_t::mantissa; - static constexpr int bias = bf8_t::bias; + static constexpr int exp = 5; + static constexpr int mant = 2; +#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) + static constexpr int bias = 16; +#else + static constexpr int bias = 15; // IEEE +#endif }; template -struct 
numeric_limits; +struct numeric; template <> -struct numeric_limits +struct numeric { // minimum finite value, or minimum positive normalized value for float - CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); } + CK_TILE_HOST_DEVICE static constexpr fp8_t min() + { + return bit_cast(static_cast(0x08)); + } // minumum finite value - CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); } + CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() + { + return bit_cast(static_cast(0xff)); + } // maximum finite value - CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); } + CK_TILE_HOST_DEVICE static constexpr fp8_t max() + { + return bit_cast(static_cast(0x7f)); + } // difference between 1.0 and next value representable by float - CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); } + CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() + { + return bit_cast(static_cast(0x20)); + } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); } + CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); } // positive infinity value - CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() + { + return bit_cast(static_cast(0x80)); + } // quiet NaN - CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() + { + return bit_cast(static_cast(0x80)); + } // signaling NaN - CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() + { + return bit_cast(static_cast(0x80)); + } // smallest positive subnormal value - CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); } + 
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr fp8_t zero() + { + return bit_cast(static_cast(0)); + } }; template <> -struct numeric_limits +struct numeric { // minimum finite value, or minimum positive normalized value for float - CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); } + CK_TILE_HOST_DEVICE static constexpr bf8_t min() + { + return bit_cast(static_cast(0x04)); + } // minumum finite value - CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); } + CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() + { + return bit_cast(static_cast(0xff)); + } // maximum finite value - CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); } + CK_TILE_HOST_DEVICE static constexpr bf8_t max() + { + return bit_cast(static_cast(0x7f)); + } // difference between 1.0 and next value representable by float - CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); } + CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() + { + return bit_cast(static_cast(0x34)); + } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); } + CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); } // positive infinity value - CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() + { + return bit_cast(static_cast(0x80)); + } // quiet NaN - CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() + { + return bit_cast(static_cast(0x80)); + } // signaling NaN - CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); } + CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() + { + return 
bit_cast(static_cast(0x80)); + } // smallest positive subnormal value - CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); } + CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() + { + return bit_cast(static_cast(0x01)); + } + + CK_TILE_HOST_DEVICE static constexpr bf8_t zero() + { + return bit_cast(static_cast(0)); + } }; +#if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t) CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t) +#endif // math CK_TILE_HOST_DEVICE -fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); } +fp8_t abs(const fp8_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7f)); +} CK_TILE_HOST_DEVICE bool isnan(const fp8_t& x) { - uint8_t xx = x.get(); + uint8_t xx = bit_cast(x); return xx == 0x80; // TODO: NANOO } @@ -731,12 +830,15 @@ CK_TILE_DEVICE fp8_t log(fp8_t x) { return static_cast(__logf(static_cast(x))); }; CK_TILE_HOST_DEVICE -bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); } +bf8_t abs(const bf8_t& x) +{ + return bit_cast(static_cast(bit_cast(x) & 0x7f)); +} CK_TILE_HOST_DEVICE bool isnan(const bf8_t& x) { - uint8_t xx = x.get(); + uint8_t xx = bit_cast(x); return xx == 0x80; // TODO: NANOO } diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 60ef6c978e..dfe1d6461c 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -2,33 +2,34 @@ // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck_tile/core/config.hpp" -#include "ck_tile/core/numeric/arithmetic.hpp" #include "ck_tile/core/utility/bit_cast.hpp" -#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" #include #pragma once namespace ck_tile { -using fp16_hip_t = __half; // most of hip internal function use this type +using fp16_hip_t = _Float16; // most of hip internal function use this type +using fp16_raw_t = uint16_t; CK_TILE_HOST_DEVICE -float fp16_to_float_hip(const fp16_hip_t& x); +constexpr float fp16_to_float_hip(const fp16_hip_t& x); CK_TILE_HOST_DEVICE -double fp16_to_double_hip(const fp16_hip_t& x); +constexpr double fp16_to_double_hip(const fp16_hip_t& x); CK_TILE_HOST_DEVICE -fp16_hip_t float_to_fp16_hip(const float& x); +constexpr fp16_hip_t float_to_fp16_hip(const float& x); CK_TILE_HOST_DEVICE -fp16_hip_t double_to_fp16_hip(const double& x); +constexpr fp16_hip_t double_to_fp16_hip(const double& x); +#if CK_TILE_USE_CUSTOM_DATA_TYPE // HIP use fp16_hip_t as interchangable data type for float16 struct alignas(2) half_t { - using raw_type = uint16_t; + using raw_type = fp16_raw_t; raw_type data; CK_TILE_HOST_DEVICE @@ -83,6 +84,9 @@ struct alignas(2) half_t return static_cast(fp16_to_float_hip(to_fp16())); } + CK_TILE_HOST_DEVICE + explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast(data); } + // internal access CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } @@ -91,86 +95,132 @@ struct alignas(2) half_t constexpr raw_type get() const { return data; } }; +template +struct native_t; + +template <> +struct native_t +{ + using type = _Float16; +}; + using fp16_t = half_t; -using fp16_raw_t = typename fp16_t::raw_type; +using fp16_raw_t = typename half_t::raw_type; +#else +using fp16_t = _Float16; +using half_t = _Float16; +using fp16_raw_t = ushort; +#endif // conversions CK_TILE_HOST_DEVICE -float fp16_to_float_hip(const fp16_hip_t& x) +constexpr float fp16_to_float_hip(const fp16_hip_t& x) { // 
return __half2float(x); return static_cast(x); } CK_TILE_HOST_DEVICE -double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast(fp16_to_float_hip(x)); } +constexpr double fp16_to_double_hip(const fp16_hip_t& x) +{ + return static_cast(fp16_to_float_hip(x)); +} CK_TILE_HOST_DEVICE -fp16_hip_t float_to_fp16_hip(const float& x) +constexpr fp16_hip_t float_to_fp16_hip(const float& x) { return __float2half(x); // return static_cast(x); } CK_TILE_HOST_DEVICE -fp16_hip_t double_to_fp16_hip(const double& x) +constexpr fp16_hip_t double_to_fp16_hip(const double& x) { // return __float2half(x); return static_cast(x); } CK_TILE_HOST_DEVICE -float fp16_to_float(const half_t& x) { return static_cast(x); } +constexpr float fp16_to_float(const half_t& x) { return static_cast(x); } CK_TILE_HOST_DEVICE -float fp16_to_double(const half_t& x) { return static_cast(x); } +constexpr float fp16_to_double(const half_t& x) { return static_cast(x); } CK_TILE_HOST_DEVICE -half_t float_to_fp16(const float& x) { return half_t{x}; } +constexpr half_t float_to_fp16(const float& x) { return static_cast(x); } CK_TILE_HOST_DEVICE -half_t double_to_fp16(const double& x) { return half_t{x}; } +constexpr half_t double_to_fp16(const double& x) { return static_cast(x); } // limits template -struct numeric_limits; +struct numeric; template <> -struct numeric_limits +struct numeric { // minimum finite value, or minimum positive normalized value for float - CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); } + CK_TILE_HOST_DEVICE static constexpr half_t min() + { + return bit_cast(static_cast(0x0400)); + } // minumum finite value - CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); } + CK_TILE_HOST_DEVICE static constexpr half_t lowest() + { + return bit_cast(static_cast(0xFBFF)); + } // maximum finite value - CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); } + CK_TILE_HOST_DEVICE static 
constexpr half_t max() + { + return bit_cast(static_cast(0x7BFF)); + } // difference between 1.0 and next value representable by float - CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); } + CK_TILE_HOST_DEVICE static constexpr half_t epsilon() + { + return bit_cast(static_cast(0x1800)); + } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); } + CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast(0.5f); } // positive infinity value - CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); } + CK_TILE_HOST_DEVICE static constexpr half_t infinity() + { + return bit_cast(static_cast(0x7C00)); + } // quiet NaN - CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); } + CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } // signaling NaN - CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); } + CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() + { + return bit_cast(static_cast(0x7FFF)); + } // smallest positive subnormal value - CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); } + CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() + { + return bit_cast(static_cast(0x0001)); + } + + CK_TILE_HOST_DEVICE static constexpr half_t zero() + { + return bit_cast(static_cast(0)); + } }; template -struct numeric_utils; +struct numeric_traits; template <> -struct numeric_utils +struct numeric_traits { static constexpr int exp = 5; static constexpr int mant = 10; @@ -186,9 +236,12 @@ struct numeric_utils using bitwise_type = uint16_t; }; +#if CK_TILE_USE_CUSTOM_DATA_TYPE // arithmetic -CK_TILE_DEVICE -bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); } +CK_TILE_DEVICE bool operator==(const half_t& 
x, const half_t& y) +{ + return __heq(x.to_fp16(), y.to_fp16()); +} CK_TILE_DEVICE bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); } @@ -205,6 +258,7 @@ bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.t CK_TILE_DEVICE bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); } +#if 0 CK_TILE_DEVICE half_t operator+(const half_t& x, const half_t& y) { @@ -289,12 +343,15 @@ half_t operator--(half_t& x, int) x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16())); return y; } +#endif +#if CK_TILE_USE_CUSTOM_DATA_TYPE CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t) +#endif // math CK_TILE_HOST_DEVICE -half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); } +half_t abs(const half_t& x) { return bit_cast(x.get() & 0x7fff); } CK_TILE_HOST_DEVICE bool isnan(const half_t& x) @@ -317,5 +374,5 @@ half_t exp2(half_t x) { return static_cast(exp2f(static_cast(x))) CK_TILE_DEVICE half_t log(half_t x) { return static_cast(__logf(static_cast(x))); }; - +#endif } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/arithmetic.hpp b/include/ck_tile/core/numeric/numeric.hpp similarity index 69% rename from include/ck_tile/core/numeric/arithmetic.hpp rename to include/ck_tile/core/numeric/numeric.hpp index ad45a45e15..35745b12d2 100644 --- a/include/ck_tile/core/numeric/arithmetic.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -1,9 +1,103 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. -#include #pragma once +#include "ck_tile/core/config.hpp" +#include +#include + +namespace ck_tile { + +// this struct has the information of +// 1. limit of a certain type, similar to std::numeric_limits +// 2. some pre-defined value, zero, one... 
+// +template +struct numeric +{ + // minimum finite value, or minimum positive normalized value for float + CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits::min(); } + + // minimum finite value + CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits::lowest(); } + + // maximum finite value + CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits::max(); } + + // difference between 1.0 and next value representable by float + CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits::epsilon(); } + + // maximum rounding error + CK_TILE_HOST_DEVICE static constexpr T round_error() + { + return std::numeric_limits::round_error(); + } + + // positive infinity value + CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits::infinity(); } + + // quiet NaN + CK_TILE_HOST_DEVICE static constexpr T quiet_NaN() + { + return std::numeric_limits::quiet_NaN(); + } + + // signaling NaN + CK_TILE_HOST_DEVICE static constexpr T signaling_NaN() + { + return std::numeric_limits::signaling_NaN(); + } + + // smallest positive subnormal value + CK_TILE_HOST_DEVICE static constexpr T denorm_min() + { + return std::numeric_limits::denorm_min(); + } + + CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast(0); } + + CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast(1); } + +#ifndef C_LOG2E +#define C_LOG2E 1.44269504088896340736 // log2(e) +#endif + + CK_TILE_HOST_DEVICE static constexpr T log2e() + { + if constexpr(std::is_same_v || std::is_same_v) + { + return static_cast(C_LOG2E); + } + else + { + return 0; // TODO: integer? 
+ } + } +}; + +template +struct numeric_traits; + +template <> +struct numeric_traits +{ + static constexpr int exp = 8; + static constexpr int mant = 23; + static constexpr int bias = 127; + static constexpr uint32_t nan_mask = 0x7F800000; + static constexpr uint32_t head_mask = 0xFF800000; + static constexpr uint32_t mant_mask = 0x7FFFFF; + static constexpr uint32_t exp_mask = 0xFF; + static constexpr uint32_t Inf = 0x7F800000; + static constexpr uint32_t NegInf = 0xFF800000; + static constexpr uint32_t NaN = 0x7F800001; + static constexpr uint32_t Neg0 = 0x80000000; + using bitwise_type = uint32_t; +}; + +} // namespace ck_tile + #define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \ attr_ bool operator==(const type_& x, const type_& y) \ { \ diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 81bd55ee86..cb18cde70d 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -13,13 +13,13 @@ namespace ck_tile { +#if CK_TILE_USE_CUSTOM_DATA_TYPE template CK_TILE_HOST_DEVICE constexpr remove_cvref_t type_convert(const X& x) { return static_cast(x); } - -#if 0 +#else // Convert X to Y, both X and Y are non-const data types. 
template (type_convert(x)); } -#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \ +#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \ template <> \ CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ { \ - return stype_##_to_##dtype_(x); \ + return sname_##_to_##dname_(x); \ } -CK_TILE_TYPE_CONVERT(float, fp16_t) -CK_TILE_TYPE_CONVERT(float, bf16_t) -CK_TILE_TYPE_CONVERT(float, fp8_t) -CK_TILE_TYPE_CONVERT(float, bf8_t) +CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16) +CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16) +CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8) +CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8) -CK_TILE_TYPE_CONVERT(fp16_t, float) -CK_TILE_TYPE_CONVERT(bf16_t, float) -CK_TILE_TYPE_CONVERT(fp8_t, float) -CK_TILE_TYPE_CONVERT(bf8_t, float) +CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float) +CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float) +CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float) +CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) #undef CK_TILE_TYPE_CONVERT #endif diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 4d02937992..78cd054180 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -14,6 +14,17 @@ namespace ck_tile { +// this structure is used to pick up the type inside +// using xxx = __attribute__((ext_vector_type(N))); +// because clang only allow native type + bool in this term (custom type will fail) +// overload this structure to let proper type + +template +struct native_t +{ + using type = remove_cvref_t; +}; + // we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay // basic type to construct a ext_vector_type you must be very careful using this, or will have lot // of compiler errors e.g. 
struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will @@ -23,7 +34,7 @@ template struct ext_vector { static constexpr index_t N = N_; - using value_type = T_; + using value_type = typename native_t>::type; static_assert(!std::is_class_v); using type = value_type __attribute__((ext_vector_type(N))); // this is danguous }; @@ -52,10 +63,12 @@ struct vector_traits // below are some pre-defines of ext_vector_type // attention! 2 vector type could be just the same type // fp64 +using fp64_t = double; using fp64x2_t = double __attribute__((ext_vector_type(2))); using fp64x4_t = double __attribute__((ext_vector_type(4))); // fp32 +using fp32_t = float; using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp32x4_t = float __attribute__((ext_vector_type(4))); using fp32x8_t = float __attribute__((ext_vector_type(8))); @@ -64,6 +77,7 @@ using fp32x32_t = float __attribute__((ext_vector_type(32))); using fp32x64_t = float __attribute__((ext_vector_type(64))); // fp16 +// using fp16_t = ... using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using fp16x4_t = _Float16 __attribute__((ext_vector_type(4))); using fp16x8_t = _Float16 __attribute__((ext_vector_type(8))); @@ -71,7 +85,8 @@ using fp16x16_t = _Float16 __attribute__((ext_vector_type(16))); using fp16x32_t = _Float16 __attribute__((ext_vector_type(32))); using fp16x64_t = _Float16 __attribute__((ext_vector_type(64))); -// bfp16 +// bf16 +// using bf16_t = ... using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4))); using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8))); @@ -80,6 +95,7 @@ using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32))); using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64))); // i32 +// using int32_t = ... 
using int32x2_t = int32_t __attribute__((ext_vector_type(2))); using int32x4_t = int32_t __attribute__((ext_vector_type(4))); using int32x8_t = int32_t __attribute__((ext_vector_type(8))); @@ -88,6 +104,7 @@ using int32x32_t = int32_t __attribute__((ext_vector_type(32))); using int32x64_t = int32_t __attribute__((ext_vector_type(64))); // i16 +// using int16_t = ... using int16x2_t = int16_t __attribute__((ext_vector_type(2))); using int16x4_t = int16_t __attribute__((ext_vector_type(4))); using int16x8_t = int16_t __attribute__((ext_vector_type(8))); @@ -96,6 +113,7 @@ using int16x32_t = int16_t __attribute__((ext_vector_type(32))); using int16x64_t = int16_t __attribute__((ext_vector_type(64))); // u16 +// using uint16_t using uint16x2_t = uint16_t __attribute__((ext_vector_type(2))); using uint16x4_t = uint16_t __attribute__((ext_vector_type(4))); using uint16x8_t = uint16_t __attribute__((ext_vector_type(8))); @@ -104,6 +122,7 @@ using uint16x32_t = uint16_t __attribute__((ext_vector_type(32))); using uint16x64_t = uint16_t __attribute__((ext_vector_type(64))); // i8 +// using int8_t using int8x2_t = int8_t __attribute((ext_vector_type(2))); using int8x4_t = int8_t __attribute((ext_vector_type(4))); using int8x8_t = int8_t __attribute((ext_vector_type(8))); @@ -112,6 +131,7 @@ using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); // f8 +// using fp8_t using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2))); using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4))); using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8))); @@ -120,6 +140,7 @@ using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32))); using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64))); // bf8 +// using bf8_t using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2))); using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4))); using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8))); diff --git 
a/include/ck_tile/core/tensor/buffer_view.hpp b/include/ck_tile/core/tensor/buffer_view.hpp index efb4f2ad43..96b38241c0 100644 --- a/include/ck_tile/core/tensor/buffer_view.hpp +++ b/include/ck_tile/core/tensor/buffer_view.hpp @@ -115,7 +115,7 @@ struct buffer_view>::zero()}; } else { @@ -319,7 +319,7 @@ struct buffer_view>::zero()}; } else { @@ -666,14 +666,18 @@ struct buffer_view(&p_data_[i]); + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + // using buf_t = ushort __attribute__((ext_vector_type(8))); + auto rtn = *c_style_pointer_cast(&p_data_[i]); + return bit_cast(rtn); #endif } else { if constexpr(InvalidElementUseNumericalZeroValue) { - return X{0}; + return X{numeric>::zero()}; } else { @@ -829,7 +833,10 @@ struct buffer_view(&p_data_[i]) = x; + using buf_t = ext_vector_t>::scalar_type, + scalar_per_t_vector * scalar_per_x_vector>; + + *c_style_pointer_cast(&p_data_[i]) = reinterpret_cast(x); #endif } } @@ -948,7 +955,7 @@ struct buffer_view>::zero()}; } else { diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index d8dc4e5b53..6bcba4019c 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -10,7 +10,7 @@ #include "ck_tile/core/container/container_helper.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/core/utility/type_traits.hpp" -#include "ck_tile/core/utility/limits.hpp" +#include "ck_tile/core/numeric/numeric.hpp" namespace ck_tile { @@ -511,7 +511,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a // shift constexpr index_t adaptor0_max_hidden_id = [&]() { - index_t adaptor0_max_hidden_id_ = numeric_limits::min(); + index_t adaptor0_max_hidden_id_ = numeric::min(); static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = @@ -537,7 +537,7 @@ CK_TILE_HOST_DEVICE constexpr auto 
chain_tensor_adaptors(const TensorAdaptor0& a }(); constexpr index_t adaptor1_min_hidden_id = [&]() { - index_t adaptor1_min_hidden_id_ = numeric_limits::max(); + index_t adaptor1_min_hidden_id_ = numeric::max(); static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) { constexpr index_t ndim_low = diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 90ad94b12b..95a272c8d2 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -40,7 +40,8 @@ template ::type; // TODO: make sure all distributed tensors have same lengths and distribution // static_assert(xxx); @@ -53,7 +54,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, static_for<0, thread_buffer_size, 1>{}([&](auto i) { out_dstr_tensor.get_thread_buffer()(i) = - in_element_func(in_dstr_tensors.get_thread_buffer()[i]...); + static_cast(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...)); }); return out_dstr_tensor; diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index dc6f482abd..09a4eb1fc0 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -303,7 +303,7 @@ struct tile_window_with_static_distribution const vector_t vec_value = get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, bool_constant{}); - +#if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) { constexpr auto idx_ys = generate_array( @@ -319,7 +319,14 @@ struct tile_window_with_static_distribution dst_tensor.get_thread_buffer().template at() = vec_value.template get_as()[j]; }); +#else + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start); + static_assert(d % Traits::ScalarPerVector == 0); + dst_tensor.get_thread_buffer().template get_as()( + 
number{}) = bit_cast(vec_value); +#endif // move thread coordinate if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) { diff --git a/include/ck_tile/core/utility/limits.hpp b/include/ck_tile/core/utility/limits.hpp deleted file mode 100644 index 9a3987c177..0000000000 --- a/include/ck_tile/core/utility/limits.hpp +++ /dev/null @@ -1,75 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core/config.hpp" -#include -#include - -namespace ck_tile { - -template -struct numeric_limits -{ - // minimum finite value, or minimum positive normalized value for float - CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits::min(); } - - // minumum finite value - CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits::lowest(); } - - // maximum finite value - CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits::max(); } - - // difference between 1.0 and next value representable by float - CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits::epsilon(); } - - // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr T round_error() - { - return std::numeric_limits::round_error(); - } - - // positive infinity value - CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits::infinity(); } - - // quiet NaN - CK_TILE_HOST_DEVICE static constexpr T quiet_NaN() - { - return std::numeric_limits::quiet_NaN(); - } - - // signaling NaN - CK_TILE_HOST_DEVICE static constexpr T signaling_NaN() - { - return std::numeric_limits::signaling_NaN(); - } - - // smallest positive subnormal value - CK_TILE_HOST_DEVICE static constexpr T denorm_min() - { - return std::numeric_limits::denorm_min(); - } -}; - -template -struct numeric_utils; - -template <> -struct numeric_utils -{ - static constexpr int exp = 8; - static constexpr int mant = 23; - static constexpr int bias = 127; - 
static constexpr uint32_t nan_mask = 0x7F800000; - static constexpr uint32_t head_mask = 0xFF800000; - static constexpr uint32_t mant_mask = 0x7FFFFF; - static constexpr uint32_t exp_mask = 0xFF; - static constexpr uint32_t Inf = 0x7F800000; - static constexpr uint32_t NegInf = 0xFF800000; - static constexpr uint32_t NaN = 0x7F800001; - static constexpr uint32_t Neg0 = 0x80000000; - using bitwise_type = uint32_t; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_batched_masking.hpp b/include/ck_tile/host/reference/reference_batched_masking.hpp index fcd603fe30..eece7fc3a8 100644 --- a/include/ck_tile/host/reference/reference_batched_masking.hpp +++ b/include/ck_tile/host/reference/reference_batched_masking.hpp @@ -21,7 +21,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor& c_b_m_n, cons for(int m = 0; m < M; ++m) { if(mask.IsOutOfBound(m, n)) - c_b_m_n(batch, m, n) = -ck_tile::numeric_limits::infinity(); + c_b_m_n(batch, m, n) = -ck_tile::numeric::infinity(); } } }; diff --git a/include/ck_tile/host/reference/reference_batched_softmax.hpp b/include/ck_tile/host/reference/reference_batched_softmax.hpp index 4f748e0f0c..57de0ca243 100644 --- a/include/ck_tile/host/reference/reference_batched_softmax.hpp +++ b/include/ck_tile/host/reference/reference_batched_softmax.hpp @@ -18,7 +18,7 @@ CK_TILE_HOST void reference_batched_softmax( const int N = a_b_m_n.mDesc.get_lengths()[2]; auto f = [&](auto batch, auto m) { - CompDataType v_max = -ck_tile::numeric_limits::infinity(); + CompDataType v_max = -ck_tile::numeric::infinity(); // max for(int n = 0; n < N; ++n) diff --git a/include/ck_tile/host/reference/reference_softmax.hpp b/include/ck_tile/host/reference/reference_softmax.hpp index f73579fd5c..f1404f85a8 100644 --- a/include/ck_tile/host/reference/reference_softmax.hpp +++ b/include/ck_tile/host/reference/reference_softmax.hpp @@ -16,7 +16,7 @@ CK_TILE_HOST void reference_softmax(const HostTensor& a_m_n, auto f = [&](auto m) 
{ const int N = a_m_n.mDesc.get_lengths()[1]; - AccDataType v_max = ck_tile::numeric_limits::Lowest(); + AccDataType v_max = ck_tile::numeric::Lowest(); // max for(int n = 0; n < N; ++n) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 1fe6415453..98866805a0 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -50,7 +50,7 @@ struct FmhaFwdKernel // clang-format off template struct t2s; template <> struct t2s { static constexpr const char * name = "fp32"; }; - template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; template <> struct t2s { static constexpr const char * name = "bf16"; }; template <> struct t2s { static constexpr const char * name = "fp8"; }; template <> struct t2s { static constexpr const char * name = "bf8"; }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index c2953fc2ea..098d9d363d 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -190,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric_limits::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); @@ -209,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric_limits::infinity()); + set_tile(lse, -numeric::infinity()); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -347,15 +347,12 @@ struct BlockFmhaPipelineQRKSVS number{}); if(need_perpixel_check) { - set_tile_if(s_acc, - -numeric_limits::infinity(), - [&](auto 
tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } @@ -364,7 +361,7 @@ struct BlockFmhaPipelineQRKSVS s, sequence<1>{}, f_max, - -numeric_limits::infinity()); // m_local = rowmax(S{j}) + -numeric::infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -379,7 +376,7 @@ struct BlockFmhaPipelineQRKSVS /// consideration if constexpr(kHasBias || FmhaMask::IsMasking) { - return raw_m == -numeric_limits::infinity() + return raw_m == -numeric::infinity() ? type_convert(0.f) : raw_m; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 76c20bfe46..c7e2f3ae4b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -232,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric_limits::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); __builtin_amdgcn_sched_barrier(0); @@ -252,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric_limits::infinity()); + set_tile(lse, -numeric::infinity()); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -390,15 +390,12 @@ struct BlockFmhaPipelineQRKSVSAsync number{}); if(need_perpixel_check) { - set_tile_if(s_acc, - -numeric_limits::infinity(), - 
[&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if( + s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } @@ -407,7 +404,7 @@ struct BlockFmhaPipelineQRKSVSAsync s, sequence<1>{}, f_max, - -numeric_limits::infinity()); // m_local = rowmax(S{j}) + -numeric::infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -458,7 +455,7 @@ struct BlockFmhaPipelineQRKSVSAsync /// consideration if constexpr(kHasBias || FmhaMask::IsMasking) { - return raw_m == -numeric_limits::infinity() + return raw_m == -numeric::infinity() ? type_convert(0.f) : raw_m; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 0643e7c0d9..5476282e04 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -182,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8 auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric_limits::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); const auto q_origin = q_dram_window.get_window_origin(); @@ -330,15 +330,12 @@ struct BlockFmhaPipelineQRKSVSFp8 number{}); if(need_perpixel_check) { - set_tile_if(s_acc, - -numeric_limits::infinity(), - [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if( + s_acc, -numeric::infinity(), 
[&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } @@ -347,7 +344,7 @@ struct BlockFmhaPipelineQRKSVSFp8 s, sequence<1>{}, f_max, - -numeric_limits::infinity()); // m_local = rowmax(S{j}) + -numeric::infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -362,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSFp8 /// consideration if constexpr(kHasBias || FmhaMask::IsMasking) { - return raw_m == -numeric_limits::infinity() + return raw_m == -numeric::infinity() ? type_convert(0.f) : raw_m; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index e7fa19449b..e04f6660d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS auto l = MLBlockTileType{}; clear_tile(o_acc); - set_tile(m, -numeric_limits::infinity()); + set_tile(m, -numeric::infinity()); clear_tile(l); const auto q_origin = q_dram_block_window_tmp.get_window_origin(); @@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS auto lse = make_static_distributed_tensor(m.get_tile_distribution()); - set_tile(lse, -numeric_limits::infinity()); + set_tile(lse, -numeric::infinity()); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); } @@ -338,15 +338,12 @@ struct BlockFmhaPipelineQSKSVS number{}); if(need_perpixel_check) { - set_tile_if(s_acc, - -numeric_limits::infinity(), - [&](auto tile_idx) { - const auto row = - q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); - const auto col = - k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); - }); + set_tile_if( + s_acc, 
-numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsOutOfBound(row, col); + }); } } @@ -355,7 +352,7 @@ struct BlockFmhaPipelineQSKSVS s, sequence<1>{}, f_max, - -numeric_limits::infinity()); // m_local = rowmax(S{j}) + -numeric::infinity()); // m_local = rowmax(S{j}) block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} @@ -370,7 +367,7 @@ struct BlockFmhaPipelineQSKSVS /// consideration if constexpr(kHasBias || FmhaMask::IsMasking) { - return raw_m == -numeric_limits::infinity() + return raw_m == -numeric::infinity() ? type_convert(0.f) : raw_m; } diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index cbe37e8769..c7ebcf9606 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -29,4 +29,3 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -