diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 9ac55c1197..6b1c11fa27 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -54,4 +54,3 @@ #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" - diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index 8ac9545633..071387163a 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -267,7 +267,15 @@ struct numeric } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); } + // maximum rounding error + // bin : f edcba 9876543210 + // bits: s eeeeeeee mmmmmmm + // 0 01111110 0000000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() + { + return bit_cast(static_cast(0x3f00)); + } // positive infinity value CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity() diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index 05f65309d9..bad1009f2c 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -42,13 +42,13 @@ enum class fp8_rounding_mode */ template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant = {}); +CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant = {}); template (CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant = {}); +CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant = {}); -CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t); -CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t); +CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t); #if CK_TILE_USE_CUSTOM_DATA_TYPE struct alignas(1) float8_e4m3_t @@ -581,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x) // clang-format off template -CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant) +CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x); @@ -589,14 +589,14 @@ CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant -CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant) +CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant) { if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x); else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x); else return bf8_raw_t{0}; } -CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x) +CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -610,7 +610,7 @@ CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x) #endif } -CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x) +CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x) { #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) float fval; @@ -625,23 +625,23 @@ CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x) } template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant = {}) +CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant = {}) { return bit_cast(float_to_fp8_raw(x, constant{})); } template(CK_TILE_FLOAT_TO_FP8_DEFAULT)> -CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant = {}) +CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant = {}) { return bit_cast(float_to_bf8_raw(x, constant{})); } -CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x) +CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast(x)); } -CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x) +CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast(x)); } @@ -706,7 +706,14 @@ struct numeric } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); } + // bin : 7 6543 210 + // bits: s eeee mmm + // 0 0110 000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() + { + return bit_cast(static_cast(0x30)); + } // positive infinity value CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() @@ -766,7 +773,14 @@ struct numeric } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); } + // bin : 7 65432 10 + // bits: s eeeee mm + // 0 01110 00 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() + { + return bit_cast(static_cast(0x38)); + } // positive infinity value CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index dfe1d6461c..c616b6939f 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -184,7 +184,14 @@ struct numeric } // maximum rounding error - CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast(0.5f); } + // bin : f edcba 9876543210 + // bits: s eeeee mmmmmmmmmm + // 0 01110 0000000000 (0.5) + // + CK_TILE_HOST_DEVICE static constexpr half_t round_error() + { + return bit_cast(static_cast(0x3800)); + } // positive infinity value CK_TILE_HOST_DEVICE static constexpr half_t infinity() diff --git a/include/ck_tile/core/numeric/vector_type.hpp b/include/ck_tile/core/numeric/vector_type.hpp index 78cd054180..9d09e06230 100644 --- a/include/ck_tile/core/numeric/vector_type.hpp +++ b/include/ck_tile/core/numeric/vector_type.hpp @@ -130,6 +130,7 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16))); using int8x32_t = int8_t __attribute((ext_vector_type(32))); using int8x64_t = int8_t __attribute((ext_vector_type(64))); +#if CK_TILE_USE_CUSTOM_DATA_TYPE // f8 // using fp8_t using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2))); @@ -147,5 +148,24 @@ using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8))); using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16))); using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32))); using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64))); +#else +// f8 +// using fp8_t +using fp8x2_t = fp8_t __attribute((ext_vector_type(2))); +using fp8x4_t = fp8_t __attribute((ext_vector_type(4))); +using fp8x8_t = fp8_t __attribute((ext_vector_type(8))); +using fp8x16_t = fp8_t __attribute((ext_vector_type(16))); +using fp8x32_t = fp8_t __attribute((ext_vector_type(32))); +using fp8x64_t = fp8_t __attribute((ext_vector_type(64))); + +// bf8 +// using bf8_t +using bf8x2_t = bf8_t __attribute((ext_vector_type(2))); +using bf8x4_t = bf8_t __attribute((ext_vector_type(4))); +using bf8x8_t = bf8_t __attribute((ext_vector_type(8))); +using bf8x16_t = bf8_t __attribute((ext_vector_type(16))); +using bf8x32_t = bf8_t __attribute((ext_vector_type(32))); +using bf8x64_t = bf8_t __attribute((ext_vector_type(64))); +#endif } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 95a272c8d2..90ad94b12b 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -40,8 +40,7 @@ template ::type; + using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...)); // TODO: make sure all distributed tensors have same lengths and distribution // static_assert(xxx); @@ -54,7 +53,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func, static_for<0, thread_buffer_size, 1>{}([&](auto i) { out_dstr_tensor.get_thread_buffer()(i) = - static_cast(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...)); + in_element_func(in_dstr_tensors.get_thread_buffer()[i]...); }); return out_dstr_tensor; diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 1bbb4b9539..0c4a778226 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -20,4 +20,3 @@ #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" - diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 9fc1c0d0c1..4363ea1f55 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -4,4 +4,3 @@ #pragma once #include "ck_tile/ops/common/tensor_layout.hpp" - diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index ab399dbf7a..388f52c898 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -5,4 +5,3 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" - diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index f886d470d5..1e9acc6d7b 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -18,4 +18,3 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" -