mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Stop using the custom data types by default; we now get the same ISA-level code as opt_padding
This commit is contained in:
@@ -11,7 +11,7 @@ import copy
|
||||
import fnmatch
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "ck_tile::half_t",
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp8" : "ck_tile::fp8_t"
|
||||
}
|
||||
|
||||
@@ -20,13 +20,13 @@
|
||||
#include "ck_tile/core/container/statically_indexed_array.hpp"
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/numeric/type_convert.hpp"
|
||||
#include "ck_tile/core/numeric/vector_type.hpp"
|
||||
#include "ck_tile/core/tensor/buffer_view.hpp"
|
||||
@@ -49,7 +49,6 @@
|
||||
#include "ck_tile/core/tensor/tile_window.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/utility/magic_div.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
|
||||
@@ -47,8 +47,8 @@ struct base_transform
|
||||
{
|
||||
if constexpr(NDimUp > 0)
|
||||
{
|
||||
array<index_t, NDimUp> up_vector_lengths = make_array_with<index_t, NDimUp>({-1});
|
||||
array<index_t, NDimUp> up_vector_strides = make_array_with<index_t, NDimUp>({-1});
|
||||
array<index_t, NDimUp> up_vector_lengths{-1};
|
||||
array<index_t, NDimUp> up_vector_strides{-1};
|
||||
|
||||
return make_tuple(up_vector_lengths, up_vector_strides);
|
||||
}
|
||||
@@ -690,8 +690,8 @@ struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
|
||||
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
|
||||
const LowVectorStrides& low_vector_strides)
|
||||
{
|
||||
array<index_t, 1> up_vector_lengths = make_array_with<index_t, 1>({-1});
|
||||
array<index_t, 1> up_vector_strides = make_array_with<index_t, 1>({-1});
|
||||
array<index_t, 1> up_vector_lengths{-1};
|
||||
array<index_t, 1> up_vector_strides{-1};
|
||||
|
||||
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
|
||||
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
|
||||
@@ -821,8 +821,8 @@ struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
|
||||
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
|
||||
const LowVectorStrides& low_vector_strides)
|
||||
{
|
||||
array<index_t, 1> up_vector_lengths = make_array_with<index_t, 1>({-1});
|
||||
array<index_t, 1> up_vector_strides = make_array_with<index_t, 1>({-1});
|
||||
array<index_t, 1> up_vector_lengths{-1};
|
||||
array<index_t, 1> up_vector_strides{-1};
|
||||
|
||||
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
|
||||
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
|
||||
@@ -940,8 +940,8 @@ struct unmerge : public base_transform<1, UpLengths::size()>
|
||||
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
|
||||
const LowVectorStrides& low_vector_strides)
|
||||
{
|
||||
array<index_t, NDimUp> up_vector_lengths = make_array_with<index_t, NDimUp>({-1});
|
||||
array<index_t, NDimUp> up_vector_strides = make_array_with<index_t, NDimUp>({-1});
|
||||
array<index_t, NDimUp> up_vector_lengths{-1};
|
||||
array<index_t, NDimUp> up_vector_strides{-1};
|
||||
|
||||
constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
|
||||
|
||||
|
||||
@@ -414,7 +414,7 @@ struct buffer_store_if<8>
|
||||
static_assert(sizeof(T) == 8);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
// TODO: ugly. rocm-6.0/6.1 seems to need bit_cast to the same base type to avoid scratch
|
||||
using mbuf_t = ext_vector_t<typename T::value_type::raw_type, T::size()>;
|
||||
using mbuf_t = ext_vector_t<typename T::value_type, T::size()>;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %5\n"
|
||||
"buffer_store_dwordx2 %0, %1, %2, %3 offen offset:%4\n"
|
||||
"s_mov_b64 exec %6"
|
||||
@@ -1778,7 +1778,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
thread_buffer<T, N> tmp =
|
||||
amd_buffer_load_impl<T, N, coherence>(src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
if constexpr(oob_conditional_check)
|
||||
return src_thread_element_valid ? tmp : thread_buffer<T, N>{0};
|
||||
return src_thread_element_valid ? tmp : thread_buffer<T, N>{numeric<T>::zero()};
|
||||
else
|
||||
return tmp;
|
||||
#endif
|
||||
|
||||
@@ -20,6 +20,10 @@
|
||||
#define CK_TILE_DEVICE_EXTERN
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
|
||||
#endif
|
||||
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
|
||||
@@ -19,6 +19,8 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
return make_tuple(ts...);
|
||||
}
|
||||
#else
|
||||
|
||||
#if 0
|
||||
template <typename T, index_t N>
|
||||
using thread_buffer = array<T, N>;
|
||||
|
||||
@@ -27,6 +29,72 @@ CK_TILE_HOST_DEVICE constexpr auto make_thread_buffer(Ts&&... ts)
|
||||
{
|
||||
return make_array(ts...);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
template<typename T_, index_t N_>
|
||||
struct thread_buffer {
|
||||
using value_type = remove_cvref_t<T_>;
|
||||
static constexpr index_t N = N_;
|
||||
|
||||
value_type data[N];
|
||||
|
||||
// TODO: this ctor can't ignore
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer() : data{} {}
|
||||
CK_TILE_HOST_DEVICE constexpr thread_buffer(const value_type & o) : data{o} {}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto size() { return N; }
|
||||
CK_TILE_HOST_DEVICE auto & get() {return data; }
|
||||
CK_TILE_HOST_DEVICE const auto & get() const {return data; }
|
||||
CK_TILE_HOST_DEVICE auto & get(index_t i) {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE const auto & get(index_t i) const {return data[i]; }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& operator[](index_t i) const { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator[](index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr auto& operator()(index_t i) { return get(i); } // TODO: compatible
|
||||
CK_TILE_HOST_DEVICE constexpr auto& at(index_t i) { return get(i); }
|
||||
CK_TILE_HOST_DEVICE constexpr const auto& at(index_t i) const { return get(i); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at() { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
#define TB_COMMON_AS() \
|
||||
static_assert(sizeof(value_type) * N % sizeof(Tx) == 0); \
|
||||
constexpr int vx = sizeof(value_type) * N / sizeof(Tx)
|
||||
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE auto & get_as() {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE const auto & get_as() const {TB_COMMON_AS();
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data);}
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE auto & get_as(index_t i) {TB_COMMON_AS();
|
||||
return reinterpret_cast<thread_buffer<Tx, vx>&>(data).get(i);}
|
||||
template<typename Tx>
|
||||
CK_TILE_HOST_DEVICE const auto & get_as(index_t i) const {TB_COMMON_AS();
|
||||
return reinterpret_cast<const thread_buffer<Tx, vx>&>(data).get(i);}
|
||||
|
||||
template <typename Tx> CK_TILE_HOST_DEVICE constexpr void set_as(index_t i, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(i) = x; }
|
||||
template <typename Tx, index_t I> CK_TILE_HOST_DEVICE constexpr void set_as(number<I>, const Tx & x)
|
||||
{ TB_COMMON_AS(); reinterpret_cast<array<Tx, vx>&>(data).at(number<I>{}) = x; }
|
||||
#undef TB_COMMON_AS
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
template <typename>
|
||||
struct vector_traits;
|
||||
|
||||
// specialization for thread_buffer
|
||||
template <typename T, index_t N>
|
||||
struct vector_traits<thread_buffer<T, N>>
|
||||
{
|
||||
using scalar_type = T;
|
||||
static constexpr index_t vector_size = N;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -3,10 +3,9 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -22,18 +21,19 @@ enum class bf16_rounding_mode
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x);
|
||||
constexpr float bf16_to_float_raw(uint16_t x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x);
|
||||
constexpr double bf16_to_double_raw(uint16_t x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP use __hip_bfloat16 as struct
|
||||
struct alignas(2) bfloat16_t
|
||||
{
|
||||
@@ -89,13 +89,24 @@ struct alignas(2) bfloat16_t
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<bfloat16_t>
|
||||
{
|
||||
using type = ushort;
|
||||
};
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
// round to nearest
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_rtn_raw(float f)
|
||||
constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
@@ -139,7 +150,7 @@ uint16_t float_to_bf16_rtn_raw(float f)
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
@@ -151,7 +162,7 @@ uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
|
||||
// Fast truncate instead of rounding, RTZ
|
||||
CK_TILE_HOST_DEVICE
|
||||
uint16_t float_to_bf16_truc_raw(float f)
|
||||
constexpr uint16_t float_to_bf16_truc_raw(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
@@ -162,7 +173,7 @@ uint16_t float_to_bf16_truc_raw(float f)
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
@@ -173,13 +184,13 @@ CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant<rounding>)
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding>)
|
||||
{
|
||||
return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float_raw(uint16_t x)
|
||||
constexpr float bf16_to_float_raw(uint16_t x)
|
||||
{
|
||||
union
|
||||
{
|
||||
@@ -190,100 +201,118 @@ float bf16_to_float_raw(uint16_t x)
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double_raw(uint16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant<rounding>)
|
||||
constexpr double bf16_to_double_raw(uint16_t x)
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
return static_cast<double>(bf16_to_float_raw(x));
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float bf16_to_float(bfloat16_t x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double bf16_to_double(bfloat16_t x) { return static_cast<double>(x); }
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t double_to_bf16(double f, constant<rounding> = {})
|
||||
{
|
||||
return bfloat16_t::bit_cast(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t bf16_to_fp16(bfloat16_t x) { return float_to_fp16(static_cast<float>(x)); }
|
||||
constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
|
||||
|
||||
// Converts a bf16 value to double by decoding its raw 16-bit pattern to float
// and widening the result. The bit_cast keeps this valid both when bfloat16_t
// is a raw 16-bit alias and when it is the custom struct type, matching
// bf16_to_float.
CK_TILE_HOST_DEVICE
constexpr double bf16_to_double(bfloat16_t x)
{
    return static_cast<double>(bf16_to_float_raw(bit_cast<uint16_t>(x)));
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE bfloat16_t constexpr fp16_to_bf16(half_t f, constant<rounding> = {})
|
||||
{
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
|
||||
}
|
||||
|
||||
// Converts bf16 to fp16 by going through float.
// The bf16 bit pattern must be decoded explicitly: when bfloat16_t is the raw
// ushort alias, static_cast<float>(x) would convert the integer value of the
// bits (e.g. 0x3F80 -> 16256.0f) instead of the encoded number (1.0f).
CK_TILE_HOST_DEVICE
constexpr half_t bf16_to_fp16(bfloat16_t x)
{
    return static_cast<fp16_t>(bf16_to_float_raw(bit_cast<uint16_t>(x)));
}
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bfloat16_t>
|
||||
struct numeric<bfloat16_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min() { return bfloat16_t::bit_cast(0x0080); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
|
||||
}
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t lowest()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0xff7f);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max() { return bfloat16_t::bit_cast(0x7f7f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t epsilon()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x1000);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return bfloat16_t(0.5f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t round_error() { return float_to_bf16(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t infinity()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7f80);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t quiet_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t signaling_NaN()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x7FFF);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t denorm_min()
|
||||
{
|
||||
return bfloat16_t::bit_cast(0x0001);
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr bfloat16_t zero()
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bfloat16_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
bfloat16_t abs(const bfloat16_t& x) { return bfloat16_t::bit_cast(x.get() & 0x7fff); }
|
||||
bfloat16_t abs(const bfloat16_t& x)
|
||||
{
|
||||
return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
|
||||
}
|
||||
|
||||
// Returns true when x encodes a NaN.
// bfloat16 has 8 exponent bits and 7 mantissa bits, so +infinity is 0x7F80;
// any magnitude (sign bit masked off) strictly above that has an all-ones
// exponent and a non-zero mantissa, i.e. is a NaN. (0x7C00 is the fp16
// infinity pattern and would misclassify large finite bf16 values.)
CK_TILE_HOST_DEVICE
bool isnan(const bfloat16_t& x)
{
    const uint16_t magnitude = bit_cast<bf16_raw_t>(x) & 0x7FFF;
    return magnitude > 0x7F80;
}
|
||||
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <stdint.h>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -43,14 +42,15 @@ enum class fp8_rounding_mode
|
||||
*/
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_fp8_raw(float, constant<rounding> = {});
|
||||
|
||||
template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_bf8_raw(float, constant<rounding> = {});
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(uint8_t);
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(uint8_t);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
struct alignas(1) float8_e4m3_t
|
||||
{
|
||||
static constexpr int exponent = 4;
|
||||
@@ -58,7 +58,7 @@ struct alignas(1) float8_e4m3_t
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
@@ -116,7 +116,7 @@ struct alignas(1) float8_e5m2_t
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 1 << (exponent - 1); // NANOO
|
||||
#else
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
static constexpr int bias = (1 << (exponent - 1)) - 1; // IEEE
|
||||
#endif
|
||||
using raw_type = uint8_t;
|
||||
raw_type data;
|
||||
@@ -167,6 +167,28 @@ struct alignas(1) float8_e5m2_t
|
||||
using bf8_t = float8_e5m2_t;
|
||||
using bf8_raw_t = typename bf8_t::raw_type;
|
||||
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<fp8_t>
|
||||
{
|
||||
using type = _BitInt(8);
|
||||
};
|
||||
|
||||
template <>
|
||||
struct native_t<bf8_t>
|
||||
{
|
||||
using type = unsigned _BitInt(8);
|
||||
};
|
||||
|
||||
#else
|
||||
using fp8_t = _BitInt(8);
|
||||
using fp8_raw_t = uint8_t;
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
using bf8_raw_t = uint8_t;
|
||||
#endif
|
||||
|
||||
// below is sw fp8 conversion, not utilizing hw instruction
|
||||
namespace impl {
|
||||
|
||||
@@ -174,29 +196,35 @@ template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_to_f8(X x, uint32_t rng)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
constexpr int out_exp = numeric_traits<Y>::exp;
|
||||
constexpr int out_mant = numeric_traits<Y>::mant;
|
||||
|
||||
// original type exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_utils<X>::exp;
|
||||
constexpr int in_mant = numeric_utils<X>::mant;
|
||||
constexpr int in_exp = numeric_traits<X>::exp;
|
||||
constexpr int in_mant = numeric_traits<X>::mant;
|
||||
|
||||
int exponent, bias;
|
||||
uint32_t head, mantissa, sign;
|
||||
// nan code is same for float and half
|
||||
constexpr Y nan_code = __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
constexpr uint32_t nan_mask = numeric_utils<X>::nan_mask;
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
constexpr Y nan_code =
|
||||
numeric<Y>::quiet_NaN(); // __builtin_bit_cast(Y, static_cast<uint8_t>(0x80));
|
||||
#else
|
||||
constexpr Y nan_code = 0x80;
|
||||
#endif
|
||||
|
||||
constexpr uint32_t nan_mask = numeric_traits<X>::nan_mask;
|
||||
|
||||
// convert to bitwise
|
||||
using T_bitwise = typename numeric_utils<X>::bitwise_type;
|
||||
using T_bitwise = typename numeric_traits<X>::bitwise_type;
|
||||
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
|
||||
|
||||
// unpack the input, depends on datatype
|
||||
head = x_bitwise & numeric_utils<X>::head_mask;
|
||||
mantissa = x_bitwise & numeric_utils<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & numeric_utils<X>::exp_mask;
|
||||
head = x_bitwise & numeric_traits<X>::head_mask;
|
||||
mantissa = x_bitwise & numeric_traits<X>::mant_mask;
|
||||
exponent = (head >> in_mant) & numeric_traits<X>::exp_mask;
|
||||
sign = head >> (in_exp + in_mant);
|
||||
bias = numeric_utils<X>::bias;
|
||||
bias = numeric_traits<X>::bias;
|
||||
|
||||
uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
|
||||
uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1;
|
||||
@@ -335,23 +363,23 @@ template <typename X, typename Y, bool negative_zero_nan>
|
||||
CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
{
|
||||
// fp8/bf8 exponent/mantissa layout
|
||||
constexpr int in_exp = numeric_utils<X>::exp;
|
||||
constexpr int in_mant = numeric_utils<X>::mant;
|
||||
constexpr int in_exp = numeric_traits<X>::exp;
|
||||
constexpr int in_mant = numeric_traits<X>::mant;
|
||||
|
||||
// resulting type exponent/mantissa layout
|
||||
constexpr int out_exp = numeric_utils<Y>::exp;
|
||||
constexpr int out_mant = numeric_utils<Y>::mant;
|
||||
constexpr int out_exp = numeric_traits<Y>::exp;
|
||||
constexpr int out_mant = numeric_traits<Y>::mant;
|
||||
uint8_t x_raw = __builtin_bit_cast(uint8_t, x);
|
||||
|
||||
// prepare the codes
|
||||
constexpr uint8_t nan_code = 0x80;
|
||||
Y Inf, NegInf, NaN, Neg0;
|
||||
using T_bitwise = typename numeric_utils<Y>::bitwise_type;
|
||||
using T_bitwise = typename numeric_traits<Y>::bitwise_type;
|
||||
|
||||
constexpr T_bitwise Inf_bitwise = numeric_utils<Y>::Inf;
|
||||
constexpr T_bitwise NegInf_bitwise = numeric_utils<Y>::NegInf;
|
||||
constexpr T_bitwise NaN_bitwise = numeric_utils<Y>::NaN;
|
||||
constexpr T_bitwise Neg0_bitwise = numeric_utils<Y>::Neg0;
|
||||
constexpr T_bitwise Inf_bitwise = numeric_traits<Y>::Inf;
|
||||
constexpr T_bitwise NegInf_bitwise = numeric_traits<Y>::NegInf;
|
||||
constexpr T_bitwise NaN_bitwise = numeric_traits<Y>::NaN;
|
||||
constexpr T_bitwise Neg0_bitwise = numeric_traits<Y>::Neg0;
|
||||
|
||||
Inf = *(reinterpret_cast<const Y*>(&Inf_bitwise));
|
||||
NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
|
||||
@@ -384,7 +412,7 @@ CK_TILE_HOST_DEVICE Y run_cast_from_f8(X x)
|
||||
return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
|
||||
}
|
||||
|
||||
if((numeric_utils<Y>::mant == 10) && (numeric_utils<X>::mant == 2) && !negative_zero_nan)
|
||||
if((numeric_traits<Y>::mant == 10) && (numeric_traits<X>::mant == 2) && !negative_zero_nan)
|
||||
{
|
||||
retval = x_raw;
|
||||
retval <<= 8;
|
||||
@@ -553,7 +581,7 @@ CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_rtn_raw(float x)
|
||||
|
||||
// clang-format off
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_fp8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_fp8_sr_raw(x);
|
||||
@@ -561,14 +589,14 @@ CK_TILE_HOST_DEVICE fp8_raw_t float_to_fp8_raw(float x, constant<rounding>)
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
CK_TILE_HOST_DEVICE constexpr bf8_raw_t float_to_bf8_raw(float x, constant<rounding>)
|
||||
{
|
||||
if constexpr (rounding == fp8_rounding_mode::standard) return float_to_bf8_rtn_raw(x);
|
||||
else if constexpr (rounding == fp8_rounding_mode::stochastic) return float_to_bf8_sr_raw(x);
|
||||
else return bf8_raw_t{0};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float_raw(fp8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -578,11 +606,11 @@ CK_TILE_HOST_DEVICE float fp8_to_float_raw(fp8_raw_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(fp8_t::bit_cast(x));
|
||||
return impl::cast_from_f8<fp8_t, float, negative_zero_nan>(bit_cast<fp8_t>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float_raw(bf8_raw_t x)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
float fval;
|
||||
@@ -592,129 +620,200 @@ CK_TILE_HOST_DEVICE float bf8_to_float_raw(bf8_raw_t x)
|
||||
return fval;
|
||||
#else
|
||||
constexpr bool negative_zero_nan = true;
|
||||
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bf8_t::bit_cast(x));
|
||||
return impl::cast_from_f8<bf8_t, float, negative_zero_nan>(bit_cast<bf8_t>(x));
|
||||
#endif
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE float8_e4m3_t float_to_fp8(float x, constant<rounding>)
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr fp8_t float_to_fp8(float x, constant<rounding> = {})
|
||||
{
|
||||
return float8_e4m3_t::bit_cast(float_to_fp8_raw(x, constant<rounding>{}));
|
||||
return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
template<fp8_rounding_mode rounding>
|
||||
CK_TILE_HOST_DEVICE float8_e5m2_t float_to_bf8(float x, constant<rounding>)
|
||||
template<fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bf8_t float_to_bf8(float x, constant<rounding> = {})
|
||||
{
|
||||
return float8_e5m2_t::bit_cast(float_to_bf8_raw(x, constant<rounding>{}));
|
||||
return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float fp8_to_float(float8_e4m3_t x)
|
||||
CK_TILE_HOST_DEVICE constexpr float fp8_to_float(fp8_t x)
|
||||
{
|
||||
return fp8_to_float_raw(x.get());
|
||||
return fp8_to_float_raw(bit_cast<fp8_raw_t>(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE float bf8_to_float(float8_e5m2_t x)
|
||||
CK_TILE_HOST_DEVICE constexpr float bf8_to_float(bf8_t x)
|
||||
{
|
||||
return bf8_to_float_raw(x.get());
|
||||
return bf8_to_float_raw(bit_cast<bf8_raw_t>(x));
|
||||
}
|
||||
|
||||
// clang-format on
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_utils<fp8_t>
|
||||
struct numeric_traits<fp8_t>
|
||||
{
|
||||
static constexpr int exp = fp8_t::exponent;
|
||||
static constexpr int mant = fp8_t::mantissa;
|
||||
static constexpr int bias = fp8_t::bias;
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 8;
|
||||
#else
|
||||
static constexpr int bias = 7;
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_utils<bf8_t>
|
||||
struct numeric_traits<bf8_t>
|
||||
{
|
||||
static constexpr int exp = bf8_t::exponent;
|
||||
static constexpr int mant = bf8_t::mantissa;
|
||||
static constexpr int bias = bf8_t::bias;
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
static constexpr int bias = 16;
|
||||
#else
|
||||
static constexpr int bias = 15; // IEEE
|
||||
#endif
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<fp8_t>
|
||||
struct numeric<fp8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t min() { return fp8_t::bit_cast(0x08); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t min()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
|
||||
}
|
||||
|
||||
// lowest finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest() { return fp8_t::bit_cast(0xff); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t max() { return fp8_t::bit_cast(0x7f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t max()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon() { return fp8_t::bit_cast(0x20); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return fp8_t(0.5f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t round_error() { return float_to_fp8(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity() { return fp8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN() { return fp8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN() { return fp8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return fp8_t::bit_cast(0x01); }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct numeric_limits<bf8_t>
|
||||
struct numeric<bf8_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t min() { return bf8_t::bit_cast(0x04); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t min()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
|
||||
}
|
||||
|
||||
// lowest finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest() { return bf8_t::bit_cast(0xff); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t max() { return bf8_t::bit_cast(0x7f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t max()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon() { return bf8_t::bit_cast(0x34); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return bf8_t(0.5f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t round_error() { return float_to_bf8(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity() { return bf8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN() { return bf8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN() { return bf8_t::bit_cast(0x80); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min() { return bf8_t::bit_cast(0x01); }
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
|
||||
{
|
||||
return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, fp8_t)
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, bf8_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp8_t abs(const fp8_t& x) { return fp8_t::bit_cast(x.get() & 0x7f); }
|
||||
fp8_t abs(const fp8_t& x)
|
||||
{
|
||||
return bit_cast<fp8_t>(static_cast<fp8_raw_t>(bit_cast<fp8_raw_t>(x) & 0x7f));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const fp8_t& x)
|
||||
{
|
||||
uint8_t xx = x.get();
|
||||
uint8_t xx = bit_cast<fp8_raw_t>(x);
|
||||
return xx == 0x80; // TODO: NANOO
|
||||
}
|
||||
|
||||
@@ -731,12 +830,15 @@ CK_TILE_DEVICE
|
||||
fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bf8_t abs(const bf8_t& x) { return bf8_t::bit_cast(x.get() & 0x7f); }
|
||||
bf8_t abs(const bf8_t& x)
{
    // Reinterpret x as its raw storage, keep the low 7 bits (mask 0x7f drops the top bit),
    // then reinterpret the result as bf8_t again. The `&` promotes its operands to int, so
    // the value is narrowed back to bf8's own raw type (bf8_raw_t) before the final bit_cast.
    return bit_cast<bf8_t>(static_cast<bf8_raw_t>(bit_cast<bf8_raw_t>(x) & 0x7f));
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const bf8_t& x)
|
||||
{
|
||||
uint8_t xx = x.get();
|
||||
uint8_t xx = bit_cast<bf8_raw_t>(x);
|
||||
return xx == 0x80; // TODO: NANOO
|
||||
}
|
||||
|
||||
|
||||
@@ -2,33 +2,34 @@
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/arithmetic.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp16_hip_t = __half; // most of hip internal function use this type
|
||||
using fp16_hip_t = _Float16; // most of hip internal function use this type
|
||||
using fp16_raw_t = uint16_t;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x);
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x);
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// HIP use fp16_hip_t as interchangable data type for float16
|
||||
struct alignas(2) half_t
|
||||
{
|
||||
using raw_type = uint16_t;
|
||||
using raw_type = fp16_raw_t;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
@@ -83,6 +84,9 @@ struct alignas(2) half_t
|
||||
return static_cast<int>(fp16_to_float_hip(to_fp16()));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
explicit constexpr operator fp16_hip_t() const { return ck_tile::bit_cast<fp16_hip_t>(data); }
|
||||
|
||||
// internal access
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr raw_type& get() { return data; }
|
||||
@@ -91,86 +95,132 @@ struct alignas(2) half_t
|
||||
constexpr raw_type get() const { return data; }
|
||||
};
|
||||
|
||||
template <typename>
|
||||
struct native_t;
|
||||
|
||||
template <>
|
||||
struct native_t<half_t>
|
||||
{
|
||||
using type = _Float16;
|
||||
};
|
||||
|
||||
using fp16_t = half_t;
|
||||
using fp16_raw_t = typename fp16_t::raw_type;
|
||||
using fp16_raw_t = typename half_t::raw_type;
|
||||
#else
|
||||
using fp16_t = _Float16;
|
||||
using half_t = _Float16;
|
||||
using fp16_raw_t = ushort;
|
||||
#endif
|
||||
|
||||
// conversions
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
constexpr float fp16_to_float_hip(const fp16_hip_t& x)
|
||||
{
|
||||
// return __half2float(x);
|
||||
return static_cast<float>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
double fp16_to_double_hip(const fp16_hip_t& x) { return static_cast<double>(fp16_to_float_hip(x)); }
|
||||
constexpr double fp16_to_double_hip(const fp16_hip_t& x)
|
||||
{
|
||||
return static_cast<double>(fp16_to_float_hip(x));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t float_to_fp16_hip(const float& x)
|
||||
constexpr fp16_hip_t float_to_fp16_hip(const float& x)
{
    // Convert with a plain static_cast, the same way double_to_fp16_hip and
    // fp16_to_float_hip do. This keeps the function usable in constant expressions, which
    // its constexpr declaration promises; the previous call to __float2half(x) goes through
    // a HIP runtime helper instead.
    return static_cast<fp16_hip_t>(x);
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
constexpr fp16_hip_t double_to_fp16_hip(const double& x)
|
||||
{
|
||||
// return __float2half(x);
|
||||
return static_cast<fp16_hip_t>(x);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
constexpr float fp16_to_float(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
// Converts a half_t to float via static_cast.
// NOTE(review): despite the "_to_double" name, the declared return type and the cast are both
// float, not double -- confirm whether double was intended before relying on this.
constexpr float fp16_to_double(const half_t& x) { return static_cast<float>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t float_to_fp16(const float& x) { return half_t{x}; }
|
||||
constexpr half_t float_to_fp16(const float& x) { return static_cast<half_t>(x); }
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t double_to_fp16(const double& x) { return half_t{x}; }
|
||||
constexpr half_t double_to_fp16(const double& x) { return static_cast<half_t>(x); }
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric_limits;
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric_limits<half_t>
|
||||
struct numeric<half_t>
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t min() { return half_t::bit_cast(0x0400); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0400));
|
||||
}
|
||||
|
||||
// lowest finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t lowest() { return half_t::bit_cast(0xFBFF); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t lowest()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0xFBFF));
|
||||
}
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t max() { return half_t::bit_cast(0x7BFF); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t max()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7BFF));
|
||||
}
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t epsilon() { return half_t::bit_cast(0x1800); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t epsilon()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x1800));
|
||||
}
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return half_t(0.5f); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t round_error() { return static_cast<half_t>(0.5f); }
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity() { return half_t::bit_cast(0x7C00); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t infinity()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7C00));
|
||||
}
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN() { return half_t::bit_cast(0x7FFF); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t quiet_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN() { return half_t::bit_cast(0x7FFF); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t signaling_NaN()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x7FFF));
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min() { return half_t::bit_cast(0x0001); }
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t denorm_min()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0x0001));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr half_t zero()
|
||||
{
|
||||
return bit_cast<half_t>(static_cast<fp16_raw_t>(0));
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils;
|
||||
struct numeric_traits;
|
||||
|
||||
template <>
|
||||
struct numeric_utils<half_t>
|
||||
struct numeric_traits<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
@@ -186,9 +236,12 @@ struct numeric_utils<half_t>
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// arithmetic
|
||||
CK_TILE_DEVICE
|
||||
bool operator==(const half_t& x, const half_t& y) { return __heq(x.to_fp16(), y.to_fp16()); }
|
||||
CK_TILE_DEVICE bool operator==(const half_t& x, const half_t& y)
|
||||
{
|
||||
return __heq(x.to_fp16(), y.to_fp16());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE
|
||||
bool operator!=(const half_t& x, const half_t& y) { return __hne(x.to_fp16(), y.to_fp16()); }
|
||||
@@ -205,6 +258,7 @@ bool operator>(const half_t& x, const half_t& y) { return __hgt(x.to_fp16(), y.t
|
||||
CK_TILE_DEVICE
|
||||
bool operator>=(const half_t& x, const half_t& y) { return __hge(x.to_fp16(), y.to_fp16()); }
|
||||
|
||||
#if 0
|
||||
CK_TILE_DEVICE
|
||||
half_t operator+(const half_t& x, const half_t& y)
|
||||
{
|
||||
@@ -289,12 +343,15 @@ half_t operator--(half_t& x, int)
|
||||
x = half_t(__hsub(x.to_fp16(), half_t(1.0f).to_fp16()));
|
||||
return y;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST, half_t)
|
||||
#endif
|
||||
|
||||
// math
|
||||
CK_TILE_HOST_DEVICE
|
||||
half_t abs(const half_t& x) { return half_t::bit_cast(x.get() & 0x7fff); }
|
||||
half_t abs(const half_t& x)
{
    // `x.get() & 0x7fff` is promoted to int by the bitwise-and. Narrow it back to the 16-bit
    // raw type (as the fp8/bf8 abs overloads do with their raw types) so that bit_cast
    // receives a source object of the same size as half_t.
    return bit_cast<half_t>(static_cast<fp16_raw_t>(x.get() & 0x7fff));
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
bool isnan(const half_t& x)
|
||||
@@ -317,5 +374,5 @@ half_t exp2(half_t x) { return static_cast<half_t>(exp2f(static_cast<float>(x)))
|
||||
|
||||
CK_TILE_DEVICE
|
||||
half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
#endif
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,9 +1,103 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this struct has the information of
|
||||
// 1. limit of a certain type, simliar to std::numeric_limits
|
||||
// 2. some pre-defined value, zero, one...
|
||||
//
|
||||
template <typename T>
|
||||
struct numeric
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr T round_error()
|
||||
{
|
||||
return std::numeric_limits<T>::round_error();
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::signaling_NaN();
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
|
||||
{
|
||||
return std::numeric_limits<T>::denorm_min();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }
|
||||
|
||||
#ifndef C_LOG2E
|
||||
#define C_LOG2E 1.44269504088896340736 // log2(e)
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr T log2e()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
|
||||
{
|
||||
return static_cast<T>(C_LOG2E);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0; // TODO: integer?
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Bit-layout description of a numeric type. Only specializations are defined.
template <typename T>
struct numeric_traits;

// Layout constants for 32-bit float.
template <>
struct numeric_traits<float>
{
    // unsigned integer type used to hold the bit patterns below
    using bitwise_type = uint32_t;

    // field widths (in bits) and exponent bias
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    // masks
    static constexpr bitwise_type nan_mask  = 0x7F800000; // the 8 bits below the top bit
    static constexpr bitwise_type head_mask = 0xFF800000; // the top 9 bits
    static constexpr bitwise_type mant_mask = 0x7FFFFF;   // the low 23 bits
    static constexpr bitwise_type exp_mask  = 0xFF;       // an 8-bit wide mask

    // bit patterns of special values
    static constexpr bitwise_type Inf    = 0x7F800000;
    static constexpr bitwise_type NegInf = 0xFF800000;
    static constexpr bitwise_type NaN    = 0x7F800001;
    static constexpr bitwise_type Neg0   = 0x80000000;
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
@@ -13,13 +13,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
#if CK_TILE_USE_CUSTOM_DATA_TYPE
|
||||
// Generic conversion from X to Y: the argument is converted with static_cast<Y>, and the
// function's return type is remove_cvref_t<Y>.
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
|
||||
{
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
#if 0
|
||||
#else
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
@@ -43,22 +43,22 @@ CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
|
||||
return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
|
||||
}
|
||||
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, stype_) \
|
||||
#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return stype_##_to_##dtype_(x); \
|
||||
return sname_##_to_##dname_(x); \
|
||||
}
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, fp16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf16_t)
|
||||
CK_TILE_TYPE_CONVERT(float, fp8_t)
|
||||
CK_TILE_TYPE_CONVERT(float, bf8_t)
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
|
||||
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, float)
|
||||
CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
|
||||
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
@@ -14,6 +14,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// native_t<T>::type names the <base> type to use inside
|
||||
//   using xxx = <base> __attribute__((ext_vector_type(N)));
|
||||
// because clang only allows native types (+ bool) in that position; a custom type will fail.
|
||||
// Specialize this structure for a custom type to select the proper <base> type.
|
||||
|
||||
template <typename T>
|
||||
struct native_t
|
||||
{
|
||||
// default mapping: remove_cvref_t applied to T itself
using type = remove_cvref_t<T>;
|
||||
};
|
||||
|
||||
// we name this as ext_vector purposely, because clang ext_vector_type extention only accept literay
|
||||
// basic type to construct a ext_vector_type you must be very careful using this, or will have lot
|
||||
// of compiler errors e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> will
|
||||
@@ -23,7 +34,7 @@ template <typename T_, index_t N_>
|
||||
struct ext_vector
|
||||
{
|
||||
static constexpr index_t N = N_;
|
||||
using value_type = T_;
|
||||
using value_type = typename native_t<remove_cvref_t<T_>>::type;
|
||||
static_assert(!std::is_class_v<value_type>);
|
||||
using type = value_type __attribute__((ext_vector_type(N))); // this is danguous
|
||||
};
|
||||
@@ -52,10 +63,12 @@ struct vector_traits<T __attribute__((ext_vector_type(N)))>
|
||||
// below are some pre-defines of ext_vector_type
|
||||
// attention! 2 vector type could be just the same type
|
||||
// fp64
|
||||
using fp64_t = double;
|
||||
using fp64x2_t = double __attribute__((ext_vector_type(2)));
|
||||
using fp64x4_t = double __attribute__((ext_vector_type(4)));
|
||||
|
||||
// fp32
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp32x4_t = float __attribute__((ext_vector_type(4)));
|
||||
using fp32x8_t = float __attribute__((ext_vector_type(8)));
|
||||
@@ -64,6 +77,7 @@ using fp32x32_t = float __attribute__((ext_vector_type(32)));
|
||||
using fp32x64_t = float __attribute__((ext_vector_type(64)));
|
||||
|
||||
// fp16
|
||||
// using fp16_t = ...
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp16x4_t = _Float16 __attribute__((ext_vector_type(4)));
|
||||
using fp16x8_t = _Float16 __attribute__((ext_vector_type(8)));
|
||||
@@ -71,7 +85,8 @@ using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
|
||||
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
|
||||
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bfp16
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
@@ -80,6 +95,7 @@ using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
using int32x2_t = int32_t __attribute__((ext_vector_type(2)));
|
||||
using int32x4_t = int32_t __attribute__((ext_vector_type(4)));
|
||||
using int32x8_t = int32_t __attribute__((ext_vector_type(8)));
|
||||
@@ -88,6 +104,7 @@ using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
|
||||
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i16
|
||||
// using int16_t = ...
|
||||
using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
|
||||
using int16x4_t = int16_t __attribute__((ext_vector_type(4)));
|
||||
using int16x8_t = int16_t __attribute__((ext_vector_type(8)));
|
||||
@@ -96,6 +113,7 @@ using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
|
||||
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// u16
|
||||
// using uint16_t
|
||||
using uint16x2_t = uint16_t __attribute__((ext_vector_type(2)));
|
||||
using uint16x4_t = uint16_t __attribute__((ext_vector_type(4)));
|
||||
using uint16x8_t = uint16_t __attribute__((ext_vector_type(8)));
|
||||
@@ -104,6 +122,7 @@ using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
|
||||
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i8
|
||||
// using int8_t
|
||||
using int8x2_t = int8_t __attribute((ext_vector_type(2)));
|
||||
using int8x4_t = int8_t __attribute((ext_vector_type(4)));
|
||||
using int8x8_t = int8_t __attribute((ext_vector_type(8)));
|
||||
@@ -112,6 +131,7 @@ using int8x32_t = int8_t __attribute((ext_vector_type(32)));
|
||||
using int8x64_t = int8_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// f8
|
||||
// using fp8_t
|
||||
using fp8x2_t = fp8_raw_t __attribute((ext_vector_type(2)));
|
||||
using fp8x4_t = fp8_raw_t __attribute((ext_vector_type(4)));
|
||||
using fp8x8_t = fp8_raw_t __attribute((ext_vector_type(8)));
|
||||
@@ -120,6 +140,7 @@ using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
|
||||
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));
|
||||
|
||||
// bf8
|
||||
// using bf8_t
|
||||
using bf8x2_t = bf8_raw_t __attribute((ext_vector_type(2)));
|
||||
using bf8x4_t = bf8_raw_t __attribute((ext_vector_type(4)));
|
||||
using bf8x8_t = bf8_raw_t __attribute((ext_vector_type(8)));
|
||||
|
||||
@@ -115,7 +115,7 @@ struct buffer_view<address_space_enum::generic,
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{0};
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -319,7 +319,7 @@ struct buffer_view<address_space_enum::global,
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{0};
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -666,14 +666,18 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
return tmp;
|
||||
#else
|
||||
return *c_style_pointer_cast<const X*>(&p_data_[i]);
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
// using buf_t = ushort __attribute__((ext_vector_type(8)));
|
||||
auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
|
||||
return bit_cast<X>(rtn);
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{0};
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -829,7 +833,10 @@ struct buffer_view<address_space_enum::lds,
|
||||
|
||||
__builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
|
||||
#else
|
||||
*c_style_pointer_cast<X*>(&p_data_[i]) = x;
|
||||
using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
|
||||
scalar_per_t_vector * scalar_per_x_vector>;
|
||||
|
||||
*c_style_pointer_cast<buf_t*>(&p_data_[i]) = reinterpret_cast<const buf_t&>(x);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -948,7 +955,7 @@ struct buffer_view<address_space_enum::vgpr,
|
||||
{
|
||||
if constexpr(InvalidElementUseNumericalZeroValue)
|
||||
{
|
||||
return X{0};
|
||||
return X{numeric<remove_cvref_t<T>>::zero()};
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#include "ck_tile/core/container/container_helper.hpp"
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/core/utility/limits.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -511,7 +511,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
|
||||
// shift
|
||||
constexpr index_t adaptor0_max_hidden_id = [&]() {
|
||||
index_t adaptor0_max_hidden_id_ = numeric_limits<index_t>::min();
|
||||
index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();
|
||||
|
||||
static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
@@ -537,7 +537,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
|
||||
}();
|
||||
|
||||
constexpr index_t adaptor1_min_hidden_id = [&]() {
|
||||
index_t adaptor1_min_hidden_id_ = numeric_limits<index_t>::max();
|
||||
index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();
|
||||
|
||||
static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
|
||||
constexpr index_t ndim_low =
|
||||
|
||||
@@ -40,7 +40,8 @@ template <typename InElementFunc,
|
||||
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
const InTensor&... in_dstr_tensors)
|
||||
{
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));
|
||||
using OutNativeType = typename native_t<OutDataType>::type;
|
||||
|
||||
// TODO: make sure all distributed tensors have same lengths and distribution
|
||||
// static_assert(xxx);
|
||||
@@ -53,7 +54,7 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
|
||||
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
out_dstr_tensor.get_thread_buffer()(i) =
|
||||
in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
|
||||
static_cast<OutNativeType>(in_element_func(in_dstr_tensors.get_thread_buffer()[i]...));
|
||||
});
|
||||
|
||||
return out_dstr_tensor;
|
||||
|
||||
@@ -303,7 +303,7 @@ struct tile_window_with_static_distribution
|
||||
const vector_t vec_value =
|
||||
get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
|
||||
bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
|
||||
|
||||
#if 1
|
||||
// write into distributed tensor
|
||||
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
|
||||
constexpr auto idx_ys = generate_array(
|
||||
@@ -319,7 +319,14 @@ struct tile_window_with_static_distribution
|
||||
dst_tensor.get_thread_buffer().template at<d>() =
|
||||
vec_value.template get_as<DataType>()[j];
|
||||
});
|
||||
#else
|
||||
constexpr index_t d =
|
||||
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
|
||||
static_assert(d % Traits::ScalarPerVector == 0);
|
||||
|
||||
dst_tensor.get_thread_buffer().template get_as<vector_t>()(
|
||||
number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
|
||||
#endif
|
||||
// move thread coordinate
|
||||
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
|
||||
{
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include <limits>
|
||||
#include <stdint.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
struct numeric_limits
|
||||
{
|
||||
// minimum finite value, or minimum positive normalized value for float
|
||||
CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }
|
||||
|
||||
// minumum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
|
||||
// maximum finite value
|
||||
CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }
|
||||
|
||||
// difference between 1.0 and next value representable by float
|
||||
CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }
|
||||
|
||||
// maximum rounding error
|
||||
CK_TILE_HOST_DEVICE static constexpr T round_error()
|
||||
{
|
||||
return std::numeric_limits<T>::round_error();
|
||||
}
|
||||
|
||||
// positive infinity value
|
||||
CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
|
||||
// quiet NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
|
||||
// signaling NaN
|
||||
CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
|
||||
{
|
||||
return std::numeric_limits<T>::signaling_NaN();
|
||||
}
|
||||
|
||||
// smallest positive subnormal value
|
||||
CK_TILE_HOST_DEVICE static constexpr T denorm_min()
|
||||
{
|
||||
return std::numeric_limits<T>::denorm_min();
|
||||
}
|
||||
};
|
||||
|
||||
/// Bit-level layout description of a numeric type.
/// Only specializations are defined; the primary template is declared but
/// intentionally left without a definition.
template <typename T>
struct numeric_utils;

/// Layout constants for 32-bit `float` (8 exponent bits, 23 mantissa bits).
template <>
struct numeric_utils<float>
{
    /// Unsigned integer type wide enough to hold the raw bits of a float.
    using bitwise_type = uint32_t;

    // ---- field widths / exponent bias ----------------------------------
    static constexpr int exp  = 8;   // number of exponent bits
    static constexpr int mant = 23;  // number of mantissa bits
    static constexpr int bias = 127; // exponent bias

    // ---- bit masks -----------------------------------------------------
    static constexpr bitwise_type nan_mask  = 0x7F800000; // all exponent bits set
    static constexpr bitwise_type head_mask = 0xFF800000; // sign bit + exponent bits
    static constexpr bitwise_type mant_mask = 0x7FFFFF;   // low 23 (mantissa) bits
    // 8-bit mask matching the exponent width; presumably meant to be applied
    // to an exponent that has already been shifted down -- confirm at call sites.
    static constexpr bitwise_type exp_mask = 0xFF;

    // ---- raw encodings of special values -------------------------------
    static constexpr bitwise_type Inf    = 0x7F800000; // +infinity
    static constexpr bitwise_type NegInf = 0xFF800000; // -infinity
    static constexpr bitwise_type NaN    = 0x7F800001; // exponent all ones, non-zero mantissa
    static constexpr bitwise_type Neg0   = 0x80000000; // negative zero (sign bit only)
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -21,7 +21,7 @@ CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, cons
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
if(mask.IsOutOfBound(m, n))
|
||||
c_b_m_n(batch, m, n) = -ck_tile::numeric_limits<CDataType>::infinity();
|
||||
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -18,7 +18,7 @@ CK_TILE_HOST void reference_batched_softmax(
|
||||
const int N = a_b_m_n.mDesc.get_lengths()[2];
|
||||
|
||||
auto f = [&](auto batch, auto m) {
|
||||
CompDataType v_max = -ck_tile::numeric_limits<CompDataType>::infinity();
|
||||
CompDataType v_max = -ck_tile::numeric<CompDataType>::infinity();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
|
||||
@@ -16,7 +16,7 @@ CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
|
||||
auto f = [&](auto m) {
|
||||
const int N = a_m_n.mDesc.get_lengths()[1];
|
||||
|
||||
AccDataType v_max = ck_tile::numeric_limits<ADataType>::Lowest();
|
||||
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
|
||||
|
||||
// max
|
||||
for(int n = 0; n < N; ++n)
|
||||
|
||||
@@ -50,7 +50,7 @@ struct FmhaFwdKernel
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
|
||||
template <> struct t2s<ck_tile::half_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
|
||||
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
|
||||
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
|
||||
|
||||
@@ -190,7 +190,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -209,7 +209,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -347,15 +347,12 @@ struct BlockFmhaPipelineQRKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,7 +361,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -379,7 +376,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -232,7 +232,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -252,7 +252,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -390,15 +390,12 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -407,7 +404,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -458,7 +455,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -182,7 +182,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -330,15 +330,12 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -347,7 +344,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -362,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSFp8
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
@@ -194,7 +194,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric_limits<SMPLComputeDataType>::infinity());
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -338,15 +338,12 @@ struct BlockFmhaPipelineQSKSVS
|
||||
number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity(),
|
||||
[&](auto tile_idx) {
|
||||
const auto row =
|
||||
q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col =
|
||||
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -355,7 +352,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
s,
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric_limits<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
@@ -370,7 +367,7 @@ struct BlockFmhaPipelineQSKSVS
|
||||
/// consideration
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_m == -numeric_limits<SMPLComputeDataType>::infinity()
|
||||
return raw_m == -numeric<SMPLComputeDataType>::infinity()
|
||||
? type_convert<SMPLComputeDataType>(0.f)
|
||||
: raw_m;
|
||||
}
|
||||
|
||||
@@ -29,4 +29,3 @@
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user