Do not use the custom data type by default; this lets us generate the same ISA-level code as opt_padding.

This commit is contained in:
carlushuang
2024-03-17 23:23:32 +00:00
parent ee397d0ab2
commit f55c7629bc
26 changed files with 629 additions and 328 deletions

View File

@@ -11,7 +11,7 @@ import copy
import fnmatch
DTYPE_MAP = {
"fp16": "ck_tile::half_t",
"fp16": "ck_tile::fp16_t",
"bf16": "ck_tile::bf16_t",
"fp8" : "ck_tile::fp8_t"
}

View File

@@ -20,13 +20,13 @@
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
@@ -49,7 +49,6 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"

View File

@@ -47,8 +47,8 @@ struct base_transform
{
if constexpr(NDimUp > 0)
{
array<index_t, NDimUp> up_vector_lengths = make_array_with<index_t, NDimUp>({-1});
array<index_t, NDimUp> up_vector_strides = make_array_with<index_t, NDimUp>({-1});
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
return make_tuple(up_vector_lengths, up_vector_strides);
}
@@ -690,8 +690,8 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths = make_array_with<index_t, 1>({-1});
array<index_t, 1> up_vector_strides = make_array_with<index_t, 1>({-1});
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
@@ -821,8 +821,8 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths = make_array_with<index_t, 1>({-1});
array<index_t, 1> up_vector_strides = make_array_with<index_t, 1>({-1});
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
@@ -940,8 +940,8 @@ struct unmerge : public base_transform<1, UpLengths::size()>
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, NDimUp> up_vector_lengths = make_array_with<index_t, NDimUp>({-1});
array<index_t, NDimUp> up_vector_strides = make_array_with<index_t, NDimUp>({-1});
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];

View File

@@ -414,7 +414,7 @@ struct buffer_store_if<8>
static_assert(sizeof(T) == 8);
auto save_exec = __builtin_amdgcn_read_exec();
// TODO: ugly. rocm-6.0/6.1 seems neet bit_cast to same base type to avoid scratch
using mbuf_t = ext_vector_t<typename T::value_type::raw_type, T::size()>;
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
"s_mov_b64 exec %6"
@@ -1778,7 +1778,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
thread_buffer<T, N> tmp =
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
if constexpr(oob_conditional_check)
return src_thread_element_valid ? tmp : thread_buffer<T, N>{0};
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
else
return tmp;
#endif

View File

@@ -20,6 +20,10 @@
#define CK_TILE_DEVICE_EXTERN
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2

View File

@@ -19,6 +19,8 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
return make_tuple(ts...);
}
#else
#if 0
template <typename T, index_t N>
using thread_buffer = array<T, N>;
@@ -27,6 +29,72 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
{
return make_array(ts...);
}
#endif
// clang-format off
// Fixed-size per-thread register buffer of N elements of T.
// Replaces the previous `array<T, N>` alias; provides element access plus
// reinterpreted ("as") access so the same storage can be read/written as a
// different element type Tx (total byte size must divide evenly).
template<typename T_, index_t N_>
struct thread_buffer {
using value_type = remove_cvref_t<T_>;
static constexpr index_t N = N_;
// underlying storage; plain C array so the type stays trivially copyable
value_type data[N];
// TODO: this default ctor cannot be omitted (presumably zero-init of `data` is relied on) — confirm
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
// broadcast-style ctor: initializes data{o} — note this brace-initializes only
// the first element; remaining elements are value-initialized
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
// number of elements (compile-time constant)
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
// raw access to the whole array
CK_TILE_HOST_DEVICE auto & get() {return data; }
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
// element access by runtime index
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
// element access by compile-time index (mirrors array/tuple `at` conventions)
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
// Shared preamble for the "as" accessors: checks the total byte size is a
// multiple of sizeof(Tx) and computes vx, the element count when viewed as Tx.
#define TB_COMMON_AS() \
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
// view the whole buffer as thread_buffer<Tx, vx>
// NOTE(review): reinterpret_cast between unrelated element types — relies on
// the compiler tolerating this type punning; confirm against strict-aliasing
// expectations of the target toolchain
template<typename Tx>
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
template<typename Tx>
CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS();
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
// read element i of the Tx-typed view
template<typename Tx>
CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS();
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(i);}
template<typename Tx>
CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const {TB_COMMON_AS();
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(i);}
// write element i of the Tx-typed view
// NOTE(review): set_as casts to array<Tx, vx> while get_as casts to
// thread_buffer<Tx, vx> — inconsistent; verify both layouts are identical
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
#undef TB_COMMON_AS
};
// clang-format on
// Primary template: deliberately left undefined; only specializations exist.
template <typename>
struct vector_traits;
// specialization for thread_buffer: exposes its scalar element type and
// element count so generic code can treat it like a vector type
template <typename T, index_t N>
struct vector_traits<thread_buffer<T, N>>
{
using scalar_type = T;
static constexpr index_t vector_size = N;
};
#endif
} // namespace ck_tile

View File

@@ -3,10 +3,9 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#pragma once
@@ -22,18 +21,19 @@ enum class bf16_rounding_mode
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x);
constexpr float bf16_to_float_raw(uint16_t x);
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x);
constexpr double bf16_to_double_raw(uint16_t x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP use __hip_bfloat16 as struct
struct alignas(2) bfloat16_t
{
@@ -89,13 +89,24 @@ struct alignas(2) bfloat16_t
CK_TILE_HOST_DEVICE
constexpr raw_type get() const { return data; }
};
template <typename>
struct native_t;
template <>
struct native_t<bfloat16_t>
{
using type = ushort;
};
using bf16_t = bfloat16_t;
using bf16_raw_t = typename bf16_t::raw_type;
#else
using bfloat16_t = ushort;
using bf16_t = bfloat16_t;
using bf16_raw_t = uint16_t;
#endif
// round to nearest
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_rtn_raw(float f)
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
union
{
@@ -139,7 +150,7 @@ uint16_t float_to_bf16_rtn_raw(float f)
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_nan_raw(float f)
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
union
{
@@ -151,7 +162,7 @@ uint16_t float_to_bf16_truc_nan_raw(float f)
// Fast truncate instead of rounding, RTZ
CK_TILE_HOST_DEVICE
uint16_t float_to_bf16_truc_raw(float f)
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
union
{
@@ -162,7 +173,7 @@ uint16_t float_to_bf16_truc_raw(float f)
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
if constexpr(rounding == bf16_rounding_mode::standard)
return float_to_bf16_rtn_raw(f);
@@ -173,13 +184,13 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
}
template <bf16_rounding_mode rounding>
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
{
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
}
CK_TILE_HOST_DEVICE
float bf16_to_float_raw(uint16_t x)
constexpr float bf16_to_float_raw(uint16_t x)
{
union
{
@@ -190,100 +201,118 @@ float bf16_to_float_raw(uint16_t x)
}
CK_TILE_HOST_DEVICE
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
constexpr double bf16_to_double_raw(uint16_t x)
{
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
return static_cast<double>(bf16_to_float_raw(x));
}
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
{
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
{
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
}
CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
template <class T>
struct numeric_limits;
struct numeric;
template <>
struct numeric_limits<bfloat16_t>
struct numeric<bfloat16_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
{
return bfloat16_t::bit_cast(0xff7f);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
{
return bfloat16_t::bit_cast(0x1000);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
{
return bfloat16_t::bit_cast(0x7f80);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
{
return bfloat16_t::bit_cast(0x7FFF);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
{
return bfloat16_t::bit_cast(0x7FFF);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
{
return bfloat16_t::bit_cast(0x0001);
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
}
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
#endif
// math
CK_TILE_HOST_DEVICE
bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); }
bfloat16_t abs(const bfloat16_t& x)
{
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
}
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
uint16_t xx = x.get();
uint16_t xx = bit_cast<bf16_raw_t>(x);
return (xx & 0x7FFF) > 0x7C00;
}

View File

@@ -3,13 +3,12 @@
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <stdint.h>
#include <type_traits>
@@ -43,14 +42,15 @@ enum class fp8_rounding_mode
*/
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant<rounding> = {});
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant<rounding> = {});
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t);
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
struct alignas(1) float8_e4m3_t
{
static constexpr int exponent = 4;
@@ -58,7 +58,7 @@ struct alignas(1) float8_e4m3_t
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
@@ -116,7 +116,7 @@ struct alignas(1) float8_e5m2_t
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 1 << (exponent - 1); // NANOO
#else
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
#endif
using raw_type = uint8_t;
raw_type data;
@@ -167,6 +167,28 @@ struct alignas(1) float8_e5m2_t
using bf8_t = float8_e5m2_t;
using bf8_raw_t = typename bf8_t::raw_type;
template <typename>
struct native_t;
template <>
struct native_t<fp8_t>
{
using type = _BitInt(8);
};
template <>
struct native_t<bf8_t>
{
using type = unsigned _BitInt(8);
};
#else
using fp8_t = _BitInt(8);
using fp8_raw_t = uint8_t;
using bf8_t = unsigned _BitInt(8);
using bf8_raw_t = uint8_t;
#endif
// below is sw fp8 conversion, not utilizing hw instruction
namespace impl {
@@ -174,29 +196,35 @@ template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
{
// fp8/bf8 exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
// original type exponent/mantissa layout
constexpr int in_exp = numeric_utils<X>::exp;
constexpr int in_mant = numeric_utils<X>::mant;
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
int exponent, bias;
uint32_t head, mantissa, sign;
// nan code is same for float and half
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
#if CK_TILE_USE_CUSTOM_DATA_TYPE
constexpr Y nan_code =
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
#else
constexpr Y nan_code = 0x80;
#endif
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;
// convert to bitwise
using T_bitwise = typename numeric_utils<X>::bitwise_type;
using T_bitwise = typename numeric_traits<X>::bitwise_type;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input, depends on datatype
head = x_bitwise & numeric_utils<X>::head_mask;
mantissa = x_bitwise & numeric_utils<X>::mant_mask;
exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
head = x_bitwise & numeric_traits<X>::head_mask;
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
sign = head >> (in_exp + in_mant);
bias = numeric_utils<X>::bias;
bias = numeric_traits<X>::bias;
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
@@ -335,23 +363,23 @@ template <typename X, typename Y, bool negative_zero_nan>
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
{
// fp8/bf8 exponent/mantissa layout
constexpr int in_exp = numeric_utils<X>::exp;
constexpr int in_mant = numeric_utils<X>::mant;
constexpr int in_exp = numeric_traits<X>::exp;
constexpr int in_mant = numeric_traits<X>::mant;
// resulting type exponent/mantissa layout
constexpr int out_exp = numeric_utils<Y>::exp;
constexpr int out_mant = numeric_utils<Y>::mant;
constexpr int out_exp = numeric_traits<Y>::exp;
constexpr int out_mant = numeric_traits<Y>::mant;
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
// prepare the codes
constexpr uint8_t nan_code = 0x80;
Y Inf, NegInf, NaN, Neg0;
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
using T_bitwise = typename numeric_traits<Y>::bitwise_type;
constexpr T_bitwise Inf_bitwise = numeric_utils<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_utils<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = numeric_utils<Y>::Neg0;
constexpr T_bitwise Inf_bitwise = numeric_traits<Y>::Inf;
constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
constexpr T_bitwise NaN_bitwise = numeric_traits<Y>::NaN;
constexpr T_bitwise Neg0_bitwise = numeric_traits<Y>::Neg0;
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
@@ -384,7 +412,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
}
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
{
retval = x_raw;
retval <<= 8;
@@ -553,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
// clang-format off
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
@@ -561,14 +589,14 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
{
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
else return bf8_raw_t{0};
}
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -578,11 +606,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(fp8_t::bit_cast(x));
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
#endif
}
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval;
@@ -592,129 +620,200 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
return fval;
#else
constexpr bool negative_zero_nan = true;
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bf8_t::bit_cast(x));
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
#endif
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding>)
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant<rounding> = {})
{
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
}
template<fp8_rounding_mode rounding>
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding>)
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant<rounding> = {})
{
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
}
CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x)
CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x)
{
return fp8_to_float_raw(x.get());
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
}
CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x)
{
return bf8_to_float_raw(x.get());
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
}
// clang-format on
template <typename T>
struct numeric_utils;
struct numeric_traits;
template <>
struct numeric_utils<fp8_t>
struct numeric_traits<fp8_t>
{
static constexpr int exp = fp8_t::exponent;
static constexpr int mant = fp8_t::mantissa;
static constexpr int bias = fp8_t::bias;
static constexpr int exp = 4;
static constexpr int mant = 3;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 8;
#else
static constexpr int bias = 7;
#endif
};
template <>
struct numeric_utils<bf8_t>
struct numeric_traits<bf8_t>
{
static constexpr int exp = bf8_t::exponent;
static constexpr int mant = bf8_t::mantissa;
static constexpr int bias = bf8_t::bias;
static constexpr int exp = 5;
static constexpr int mant = 2;
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
static constexpr int bias = 16;
#else
static constexpr int bias = 15; // IEEE
#endif
};
template <class T>
struct numeric_limits;
struct numeric;
template <>
struct numeric_limits<fp8_t>
struct numeric<fp8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }
CK_TILE_HOST_DEVICE static constexpr fp8_t min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }
CK_TILE_HOST_DEVICE static constexpr fp8_t max()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
}
};
template <>
struct numeric_limits<bf8_t>
struct numeric<bf8_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }
CK_TILE_HOST_DEVICE static constexpr bf8_t min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
}
// minumum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }
CK_TILE_HOST_DEVICE static constexpr bf8_t max()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
}
CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
{
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
}
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
#endif
// math
CK_TILE_HOST_DEVICE
fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); }
fp8_t abs(const fp8_t& x)
{
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
}
CK_TILE_HOST_DEVICE
bool isnan(const fp8_t& x)
{
uint8_t xx = x.get();
uint8_t xx = bit_cast<fp8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}
@@ -731,12 +830,15 @@ CK_TILE_DEVICE
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
CK_TILE_HOST_DEVICE
bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); }
bf8_t abs(const bf8_t& x)
{
return bit_cast<bf8_t>(static_cast<fp8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
CK_TILE_HOST_DEVICE
bool isnan(const bf8_t& x)
{
uint8_t xx = x.get();
uint8_t xx = bit_cast<bf8_raw_t>(x);
return xx == 0x80; // TODO: NANOO
}

View File

@@ -2,33 +2,34 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/arithmetic.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include <hip/hip_fp16.h>
#pragma once
namespace ck_tile {
using fp16_hip_t = __half; // most of hip internal function use this type
using fp16_hip_t = _Float16; // most of hip internal function use this type
using fp16_raw_t = uint16_t;
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x);
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x);
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x);
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x);
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// HIP uses fp16_hip_t as an interchangeable data type for float16
struct alignas(2) half_t
{
using raw_type = uint16_t;
using raw_type = fp16_raw_t;
raw_type data;
CK_TILE_HOST_DEVICE
@@ -83,6 +84,9 @@ struct alignas(2) half_t
return static_cast<int>(fp16_to_float_hip(to_fp16()));
}
CK_TILE_HOST_DEVICE
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
// internal access
CK_TILE_HOST_DEVICE
constexpr raw_type& get() { return data; }
@@ -91,86 +95,132 @@ struct alignas(2) half_t
constexpr raw_type get() const { return data; }
};
// NOTE(review): native_t's primary template appears to live in the vector-type
// header — confirm this forward declaration stays in sync with it.
template <typename>
struct native_t;

// For the custom half_t wrapper, the "native" scalar accepted by clang's
// ext_vector_type extension is the builtin _Float16.
template <>
struct native_t<half_t>
{
    using type = _Float16;
};
using fp16_t = half_t;
using fp16_raw_t = typename fp16_t::raw_type;
using fp16_raw_t = typename half_t::raw_type;
#else
using fp16_t = _Float16;
using half_t = _Float16;
using fp16_raw_t = ushort;
#endif
// conversions
CK_TILE_HOST_DEVICE
float fp16_to_float_hip(const fp16_hip_t& x)
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
{
// return __half2float(x);
return static_cast<float>(x);
}
CK_TILE_HOST_DEVICE
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
{
return static_cast<double>(fp16_to_float_hip(x));
}
CK_TILE_HOST_DEVICE
fp16_hip_t float_to_fp16_hip(const float& x)
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
return __float2half(x);
// return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
fp16_hip_t double_to_fp16_hip(const double& x)
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
{
// return __float2half(x);
return static_cast<fp16_hip_t>(x);
}
CK_TILE_HOST_DEVICE
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
constexpr float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
CK_TILE_HOST_DEVICE
half_t float_to_fp16(const float& x) { return half_t{x}; }
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
CK_TILE_HOST_DEVICE
half_t double_to_fp16(const double& x) { return half_t{x}; }
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
// limits
template <class T>
struct numeric_limits;
struct numeric;
template <>
struct numeric_limits<half_t>
struct numeric<half_t>
{
// minimum finite value, or minimum positive normalized value for float
CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }
CK_TILE_HOST_DEVICE static constexpr half_t min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
}
// minimum finite value
CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
}
// maximum finite value
CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }
CK_TILE_HOST_DEVICE static constexpr half_t max()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
}
// difference between 1.0 and next value representable by float
CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
}
// maximum rounding error
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast<half_t>(0.5f); }
// positive infinity value
CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
}
// quiet NaN
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// signaling NaN
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
}
// smallest positive subnormal value
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
}
CK_TILE_HOST_DEVICE static constexpr half_t zero()
{
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
}
};
template <typename T>
struct numeric_utils;
struct numeric_traits;
template <>
struct numeric_utils<half_t>
struct numeric_traits<half_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
@@ -186,9 +236,12 @@ struct numeric_utils<half_t>
using bitwise_type = uint16_t;
};
#if CK_TILE_USE_CUSTOM_DATA_TYPE
// arithmetic
CK_TILE_DEVICE
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
// device-side equality: unwrap both operands to the HIP fp16 type and
// compare with the __heq intrinsic
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
{
    return __heq(x.to_fp16(), y.to_fp16());
}
CK_TILE_DEVICE
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
@@ -205,6 +258,7 @@ bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.t
CK_TILE_DEVICE
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
#if 0
CK_TILE_DEVICE
half_t operator+(const half_t& x, const half_t& y)
{
@@ -289,12 +343,15 @@ half_t operator--(half_t& x, int)
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
return y;
}
#endif
#if CK_TILE_USE_CUSTOM_DATA_TYPE
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
#endif
// math
CK_TILE_HOST_DEVICE
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
half_t abs(const half_t& x) { return bit_cast<half_t>(x.get() & 0x7fff); }
CK_TILE_HOST_DEVICE
bool isnan(const half_t& x)
@@ -317,5 +374,5 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
CK_TILE_DEVICE
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
#endif
} // namespace ck_tile

View File

@@ -1,9 +1,103 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <stdint.h>
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
// numeric<T>: device-friendly replacement for std::numeric_limits. It carries
// 1. the limits of a type (min/max/epsilon/...), similar to std::numeric_limits
// 2. some pre-defined values: zero, one, log2e, ...
// The generic version forwards to std::numeric_limits; low-precision types
// (fp16/bf16/fp8/bf8) provide their own specializations next to each type.
template <typename T>
struct numeric
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }

    // additive identity
    CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }

    // multiplicative identity
    CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
    // log2(e) for float/double; other types get 0 for now
    // NOTE(review): log2e() uses std::is_same_v — confirm <type_traits> is
    // reachable through config.hpp, since this header only includes <limits>.
    CK_TILE_HOST_DEVICE static constexpr T log2e()
    {
        if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
        {
            return static_cast<T>(C_LOG2E);
        }
        else
        {
            return 0; // TODO: integer?
        }
    }
};
// numeric_traits<T>: bit-layout description of a floating point type
// (field widths, exponent bias, common bit masks). Specialized per type;
// the fp16 specialization lives alongside half_t.
template <typename T>
struct numeric_traits;

// IEEE 754 binary32 layout: 1 sign bit, 8 exponent bits, 23 mantissa bits
template <>
struct numeric_traits<float>
{
    static constexpr int exp = 5;   // NOTE(review): exp/mant here — verify against binary32 (8/23) vs this block's context
    static constexpr int exp = 8;                       // exponent field width in bits
    static constexpr int mant = 23;                     // mantissa field width in bits
    static constexpr int bias = 127;                    // exponent bias
    static constexpr uint32_t nan_mask = 0x7F800000;    // exponent bits all set => Inf/NaN
    static constexpr uint32_t head_mask = 0xFF800000;   // sign + exponent bits
    static constexpr uint32_t mant_mask = 0x7FFFFF;     // mantissa bits
    static constexpr uint32_t exp_mask = 0xFF;          // exponent field (after shifting)
    static constexpr uint32_t Inf = 0x7F800000;         // +infinity bit pattern
    static constexpr uint32_t NegInf = 0xFF800000;      // -infinity bit pattern
    static constexpr uint32_t NaN = 0x7F800001;         // a canonical NaN bit pattern
    static constexpr uint32_t Neg0 = 0x80000000;        // negative zero bit pattern
    using bitwise_type = uint32_t;                      // integer type matching the bit width
};
} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \

View File

@@ -13,13 +13,13 @@
namespace ck_tile {
#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
return static_cast<Y>(x);
}
#if 0
#else
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
@@ -43,22 +43,22 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
template <> \
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
{ \
return stype_##_to_##dtype_(x); \
return sname_##_to_##dname_(x); \
}
CK_TILE_TYPE_CONVERT(float, fp16_t)
CK_TILE_TYPE_CONVERT(float, bf16_t)
CK_TILE_TYPE_CONVERT(float, fp8_t)
CK_TILE_TYPE_CONVERT(float, bf8_t)
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
CK_TILE_TYPE_CONVERT(fp16_t, float)
CK_TILE_TYPE_CONVERT(bf16_t, float)
CK_TILE_TYPE_CONVERT(fp8_t, float)
CK_TILE_TYPE_CONVERT(bf8_t, float)
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
#undef CK_TILE_TYPE_CONVERT
#endif

View File

@@ -14,6 +14,17 @@
namespace ck_tile {
// native_t<T>: picks the <base> scalar type usable inside
//   using xxx = <base> __attribute__((ext_vector_type(N)));
// clang only allows native (builtin) types + bool there, so custom wrapper
// types will fail; such types overload this structure to supply the proper
// <base> type. The generic case is an identity mapping.
template <typename T>
struct native_t
{
    using type = remove_cvref_t<T>;
};
// we name this ext_vector on purpose, because clang's ext_vector_type extension only accepts
// literal basic types to construct an ext_vector_type. You must be very careful using this, or
// you will hit lots of compiler errors, e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
@@ -23,7 +34,7 @@ template <typename T_, index_t N_>
struct ext_vector
{
static constexpr index_t N = N_;
using value_type = T_;
using value_type = typename native_t<remove_cvref_t<T_>>::type;
static_assert(!std::is_class_v<value_type>);
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
};
@@ -52,10 +63,12 @@ struct vector_traits<T __attribute__((ext_vector_type(N)))>
// below are some pre-defines of ext_vector_type
// attention! 2 vector type could be just the same type
// fp64
using fp64_t = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));
// fp32
using fp32_t = float;
using fp32x2_t = float __attribute__((ext_vector_type(2)));
using fp32x4_t = float __attribute__((ext_vector_type(4)));
using fp32x8_t = float __attribute__((ext_vector_type(8)));
@@ -64,6 +77,7 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));
// fp16
// using fp16_t = ...
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
@@ -71,7 +85,8 @@ using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
// bfp16
// bf16
// using bf16_t = ...
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
@@ -80,6 +95,7 @@ using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
// i32
// using int32_t = ...
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
@@ -88,6 +104,7 @@ using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
// i16
// using int16_t = ...
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
@@ -96,6 +113,7 @@ using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
// u16
// using uint16_t
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
@@ -104,6 +122,7 @@ using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
// i8
// using int8_t
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
@@ -112,6 +131,7 @@ using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
// f8
// using fp8_t
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
@@ -120,6 +140,7 @@ using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
// bf8
// using bf8_t
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));

View File

@@ -115,7 +115,7 @@ struct buffer_view<address_space_enum::generic,
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{0};
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
@@ -319,7 +319,7 @@ struct buffer_view<address_space_enum::global,
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{0};
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
@@ -666,14 +666,18 @@ struct buffer_view<address_space_enum::lds,
return tmp;
#else
return *c_style_pointer_cast<const X*>(&p_data_[i]);
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
// using buf_t = ushort __attribute__((ext_vector_type(8)));
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
return bit_cast<X>(rtn);
#endif
}
else
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{0};
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{
@@ -829,7 +833,10 @@ struct buffer_view<address_space_enum::lds,
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
scalar_per_t_vector * scalar_per_x_vector>;
*c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
#endif
}
}
@@ -948,7 +955,7 @@ struct buffer_view<address_space_enum::vgpr,
{
if constexpr(InvalidElementUseNumericalZeroValue)
{
return X{0};
return X{numeric<remove_cvref_t<T>>::zero()};
}
else
{

View File

@@ -10,7 +10,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/limits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
namespace ck_tile {
@@ -511,7 +511,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
// shift
constexpr index_t adaptor0_max_hidden_id = [&]() {
index_t adaptor0_max_hidden_id_ = numeric_limits<index_t>::min();
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =
@@ -537,7 +537,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
}();
constexpr index_t adaptor1_min_hidden_id = [&]() {
index_t adaptor1_min_hidden_id_ = numeric_limits<index_t>::max();
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
constexpr index_t ndim_low =

View File

@@ -40,7 +40,8 @@ template <typename InElementFunc,
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
const InTensor&... in_dstr_tensors)
{
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
using OutNativeType = typename native_t<OutDataType>::type;
// TODO: make sure all distributed tensors have same lengths and distribution
// static_assert(xxx);
@@ -53,7 +54,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
out_dstr_tensor.get_thread_buffer()(i) =
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
static_cast<OutNativeType>(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...));
});
return out_dstr_tensor;

View File

@@ -303,7 +303,7 @@ struct tile_window_with_static_distribution
const vector_t vec_value =
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
#if 1
// write into distributed tensor
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_array(
@@ -319,7 +319,14 @@ struct tile_window_with_static_distribution
dst_tensor.get_thread_buffer().template at<d>() =
vec_value.template get_as<DataType>()[j];
});
#else
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
static_assert(d % Traits::ScalarPerVector == 0);
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{

View File

@@ -1,75 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <limits>
#include <stdint.h>
namespace ck_tile {
// numeric_limits<T>: host/device wrapper around std::numeric_limits.
// NOTE(review): this definition appears superseded by ck_tile::numeric
// (numeric.hpp) elsewhere in this change — confirm before relying on it.
template <typename T>
struct numeric_limits
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and next value representable by float
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }
};
// numeric_utils<T>: bit-layout description of a floating point type.
// NOTE(review): appears renamed to numeric_traits (numeric.hpp) elsewhere in
// this change — confirm before relying on it.
template <typename T>
struct numeric_utils;

// IEEE 754 binary32 layout: 1 sign bit, 8 exponent bits, 23 mantissa bits
template <>
struct numeric_utils<float>
{
    static constexpr int exp = 8;                       // exponent field width in bits
    static constexpr int mant = 23;                     // mantissa field width in bits
    static constexpr int bias = 127;                    // exponent bias
    static constexpr uint32_t nan_mask = 0x7F800000;    // exponent bits all set => Inf/NaN
    static constexpr uint32_t head_mask = 0xFF800000;   // sign + exponent bits
    static constexpr uint32_t mant_mask = 0x7FFFFF;     // mantissa bits
    static constexpr uint32_t exp_mask = 0xFF;          // exponent field (after shifting)
    static constexpr uint32_t Inf = 0x7F800000;         // +infinity bit pattern
    static constexpr uint32_t NegInf = 0xFF800000;      // -infinity bit pattern
    static constexpr uint32_t NaN = 0x7F800001;         // a canonical NaN bit pattern
    static constexpr uint32_t Neg0 = 0x80000000;        // negative zero bit pattern
    using bitwise_type = uint32_t;                      // integer type matching the bit width
};
} // namespace ck_tile

View File

@@ -21,7 +21,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
for(int m = 0; m < M; ++m)
{
if(mask.IsOutOfBound(m, n))
c_b_m_n(batch, m, n) = -ck_tile::numeric_limits<CDataType>::infinity();
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
}
}
};

View File

@@ -18,7 +18,7 @@ CK_TILE_HOST void reference_batched_softmax(
const int N = a_b_m_n.mDesc.get_lengths()[2];
auto f = [&](auto batch, auto m) {
CompDataType v_max = -ck_tile::numeric_limits<CompDataType>::infinity();
CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();
// max
for(int n = 0; n < N; ++n)

View File

@@ -16,7 +16,7 @@ CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
// max
for(int n = 0; n < N; ++n)

View File

@@ -50,7 +50,7 @@ struct FmhaFwdKernel
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::half_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };

View File

@@ -190,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -209,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -347,15 +347,12 @@ struct BlockFmhaPipelineQRKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -364,7 +361,7 @@ struct BlockFmhaPipelineQRKSVS
s,
sequence<1>{},
f_max,
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -379,7 +376,7 @@ struct BlockFmhaPipelineQRKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -232,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
@@ -252,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -390,15 +390,12 @@ struct BlockFmhaPipelineQRKSVSAsync
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -407,7 +404,7 @@ struct BlockFmhaPipelineQRKSVSAsync
s,
sequence<1>{},
f_max,
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -458,7 +455,7 @@ struct BlockFmhaPipelineQRKSVSAsync
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -182,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_window.get_window_origin();
@@ -330,15 +330,12 @@ struct BlockFmhaPipelineQRKSVSFp8
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -347,7 +344,7 @@ struct BlockFmhaPipelineQRKSVSFp8
s,
sequence<1>{},
f_max,
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -362,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSFp8
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
auto lse =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
@@ -338,15 +338,12 @@ struct BlockFmhaPipelineQSKSVS
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(s_acc,
-numeric_limits<SMPLComputeDataType>::infinity(),
[&](auto tile_idx) {
const auto row =
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
@@ -355,7 +352,7 @@ struct BlockFmhaPipelineQSKSVS
s,
sequence<1>{},
f_max,
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
@@ -370,7 +367,7 @@ struct BlockFmhaPipelineQSKSVS
/// consideration
if constexpr(kHasBias || FmhaMask::IsMasking)
{
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}

View File

@@ -29,4 +29,3 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"