mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add e8m0 scaled convert into CK_TILE (#2617)
* first commit * remove redundant code * modify according to comments. * fix type_convert error with scaled_type_convert
This commit is contained in:
@@ -27,6 +27,7 @@
|
||||
#include "ck_tile/core/container/thread_buffer.hpp"
|
||||
#include "ck_tile/core/container/tuple.hpp"
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/e8m0.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
|
||||
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
102
include/ck_tile/core/numeric/e8m0.hpp
Normal file
@@ -0,0 +1,102 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Unsigned representation of a conventional biased Float32 exponent.
|
||||
*
|
||||
* bias = 127;
|
||||
*
|
||||
* E8M0_1 = 0b01111111; => 2^(127-127) = 1
|
||||
* E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2
|
||||
* E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8
|
||||
* E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256
|
||||
* E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768
|
||||
* E8M0_MIN = 0b00000000; => 2^-127
|
||||
* E8M0_MAX = 0b11111110; => 2^127
|
||||
* E8M0_NAN = 0b11111111; => NaN
|
||||
*/
|
||||
|
||||
struct e8m0_bexp_t
|
||||
{
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale)
|
||||
: e8m0_bexp_t(static_cast<type>(numeric_utils<float>::get_exponent(scale)))
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
|
||||
constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; }
|
||||
|
||||
constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; }
|
||||
};
|
||||
|
||||
using e8m0_t = e8m0_bexp_t;
|
||||
using e8m0_raw_t = typename e8m0_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<e8m0_t>
|
||||
{
|
||||
using bitwise_type = e8m0_raw_t;
|
||||
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr int PackedSize = 1;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<e8m0_t>
|
||||
{
|
||||
static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127
|
||||
static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127
|
||||
static constexpr e8m0_raw_t binary_nan = 0b11111111;
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; }
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); }
|
||||
CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); }
|
||||
};
|
||||
|
||||
/**
 * @brief Decodes the stored biased exponent into the float it stands for.
 *
 * @return NaN for the NaN encoding, std::numeric_limits<float>::min() for
 *         encoding 0, and otherwise the float whose exponent field is `data`
 *         and whose remaining bits are zero.
 */
CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const
{
    using traits = numeric_traits<float>;

    if(data == numeric<e8m0_t>::binary_nan)
    {
        // Produce the NaN as a float value. Returning a numeric_traits<float>
        // constant here would go through an implicit conversion, which turns
        // an integer bit pattern into a large finite number instead of NaN.
        return std::numeric_limits<float>::quiet_NaN();
    }

    if(data == 0)
    {
        // Shifting 0 into the exponent field would give 0.0f, so this case is
        // special. NOTE(review): encoding 0 is documented as 2^-127, but the
        // smallest normal float returned here is 2^-126 — confirm whether this
        // is deliberate (e.g. to avoid a subnormal result).
        return std::numeric_limits<float>::min();
    }

    // Move the exponent byte up by the float mantissa width so it lands in the
    // float's exponent field; sign and mantissa bits stay zero.
    return bit_cast<float>(static_cast<traits::bitwise_type>(data) << traits::mant);
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -12,15 +12,19 @@ struct numeric_utils : numeric_traits<T>
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename T::raw_type;
|
||||
using raw_type = typename traits::bitwise_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr int get_exponent(raw_type x)
|
||||
static constexpr raw_type get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr raw_type get_exponent(const T& x)
|
||||
{
|
||||
return get_exponent(bit_cast<raw_type>(x));
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
@@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits<T>
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(uint32_t i = 0; i < traits::mant; ++i)
|
||||
for(raw_type i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
@@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits<T>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
static constexpr int e8m0_bias = 127; // TODO: make it generic.
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
using utils = numeric_utils<T>;
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
|
||||
return std::ldexp(sign * mant * scale, exp);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
value /= scale;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
@@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
static constexpr int exponent = 2;
|
||||
static constexpr int mantissa = 1;
|
||||
static constexpr int bias = 1;
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
@@ -41,18 +38,27 @@ struct pk_float4_e2m1_t
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f)
|
||||
: data{float_to_e2m1(init, scale)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const;
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); }
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); }
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
|
||||
@@ -193,131 +199,160 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data);
|
||||
return impl::_from_f4<bf16_t>(data, scale);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data);
|
||||
return impl::_from_f4<bf16x2_t>(data, scale);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from other source types directly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x);
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
{
|
||||
return float_to_e2m1(x, scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0]), scale),
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
return impl::_to_f4(x, scale);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale));
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data);
|
||||
return impl::_from_f4<fp32_t>(data, scale);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data);
|
||||
return impl::_from_f4<fp32x2_t>(data, scale);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale)};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data);
|
||||
return impl::_from_f4<fp16_t>(data, scale);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data);
|
||||
return impl::_from_f4<fp16x2_t>(data, scale);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}), scale)),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{}), scale))};
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const
|
||||
{
|
||||
return e2m1_to_fp32_table[data & 0xf];
|
||||
return e2m1_to_fp32_table[unpack(number<0>{})] * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
// Table-based variant: looks up elements 0 and 1 of the packed pair in
// e2m1_to_fp32_table and multiplies each by `scale`.
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const
{
    return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale,
                    e2m1_to_fp32_table[unpack(number<1>{})] * scale};
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const
|
||||
{
|
||||
return e2m1_to_fp16_table[data & 0xf];
|
||||
return type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale;
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const
|
||||
{
|
||||
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
|
||||
return fp16x2_t{
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<0>{})]) * scale),
|
||||
type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)};
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale);
|
||||
|
||||
#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert<dtype_, stype_>(stype_ x, \
|
||||
float scale) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, scale); \
|
||||
} \
|
||||
template <> \
|
||||
CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
|
||||
{ \
|
||||
return sname_##_to_##dname_(x, 1.f); \
|
||||
}
|
||||
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4)
|
||||
#undef CK_TILE_SCALED_TYPE_CONVERT
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -409,7 +409,13 @@ struct HostTensor
|
||||
}
|
||||
|
||||
// void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
|
||||
void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
|
||||
void SetZero()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, e8m0_t>)
|
||||
std::fill(mData.begin(), mData.end(), e8m0_t{1.f});
|
||||
else
|
||||
std::fill(mData.begin(), mData.end(), 0);
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp)
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp)
|
||||
add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp)
|
||||
endif()
|
||||
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
162
test/ck_tile/data_type/test_mx_scale.cpp
Normal file
@@ -0,0 +1,162 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using ck_tile::bf16_t;
|
||||
using ck_tile::bf16x2_t;
|
||||
using ck_tile::fp16_t;
|
||||
using ck_tile::fp16x2_t;
|
||||
using ck_tile::fp32_t;
|
||||
using ck_tile::fp32x2_t;
|
||||
using ck_tile::number;
|
||||
using ck_tile::pk_fp4_t;
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert();
|
||||
|
||||
using ck_tile::e8m0_raw_t;
|
||||
using ck_tile::e8m0_t;
|
||||
|
||||
// numeric<e8m0_t>: no infinity, zero() reports the NaN encoding, and min/max
// are the all-zero byte and 0b11111110.
TEST(OCP_Scale, NumericLimits)
{
    using limits = ck_tile::numeric<e8m0_t>;

    EXPECT_EQ(limits::has_inf(), false);
    EXPECT_EQ(limits::zero(), limits::signaling_NaN());
    EXPECT_EQ(limits::min(), e8m0_t{e8m0_raw_t{0b00000000}});
    EXPECT_EQ(limits::max(), e8m0_t{e8m0_raw_t{0b11111110}});
}
|
||||
// Constructing from a power-of-two float must give the same encoding as
// constructing from the matching biased exponent byte.
TEST(OCP_Scale, NumericBasic)
{
    constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;

    const auto one_from_float = e8m0_t{1.0f};
    const auto one_from_bits  = e8m0_t{e8m0_raw_t{bias}}; // 2^0
    EXPECT_EQ(one_from_float, one_from_bits);

    const auto eight_from_float = e8m0_t{8.0f};
    const auto eight_from_bits  = e8m0_t{e8m0_raw_t{3 + bias}}; // 2^3
    EXPECT_EQ(eight_from_float, eight_from_bits);
}
|
||||
|
||||
// Runs every SRC -> fp4 -> DST combination through the device kernel path.
TEST(OCP_Scale, ScaledConvertDevice)
{
    constexpr bool on_device = true;

    // same source and destination type (e.g. fp32 -> fp4 -> fp32)
    test_convert<fp32_t, fp32_t, on_device>();
    test_convert<fp16_t, fp16_t, on_device>();
    test_convert<bf16_t, bf16_t, on_device>();

    // fp32 source, 16-bit destination
    test_convert<fp32_t, fp16_t, on_device>();
    test_convert<fp32_t, bf16_t, on_device>();

    // 16-bit source, fp32 destination
    test_convert<fp16_t, fp32_t, on_device>();
    test_convert<bf16_t, fp32_t, on_device>();
}
|
||||
// Runs every SRC -> fp4 -> DST combination through the host path.
TEST(OCP_Scale, ScaledConvertHost)
{
    constexpr bool on_device = false;

    // same source and destination type (e.g. fp32 -> fp4 -> fp32)
    test_convert<fp32_t, fp32_t, on_device>();
    test_convert<fp16_t, fp16_t, on_device>();
    test_convert<bf16_t, bf16_t, on_device>();

    // fp32 source, 16-bit destination
    test_convert<fp32_t, fp16_t, on_device>();
    test_convert<fp32_t, bf16_t, on_device>();

    // 16-bit source, fp32 destination
    test_convert<fp16_t, fp32_t, on_device>();
    test_convert<bf16_t, fp32_t, on_device>();
}
|
||||
// Smoke test: a HostTensor of e8m0_t can be created, filled and reset.
// It makes no assertions; it only has to compile and run.
TEST(OCP_Scale, tensorInit)
{
    ck_tile::HostTensor<e8m0_t> scales({10, 10});
    ck_tile::FillUniformDistribution<e8m0_t>{1.f, 1.f}(scales);
    scales.SetZero();
}
|
||||
|
||||
// Scaled conversions; `y` is the scale. toDST and toDSTx2 expand to the names
// DST and DSTx2_t, which must be visible at the point of use.
#define toPF4(x, y) ck_tile::scaled_type_convert<pk_fp4_t>(x, y)
#define toDST(x, y) ck_tile::scaled_type_convert<DST>(x, y)
#define toDSTx2(x, y) ck_tile::scaled_type_convert<DSTx2_t>(x, y)

// Unscaled conversions; the trailing underscore marks the variants that would
// otherwise clash with the scaled macros above. toSRC and toDST_ expand to the
// names SRC and DST.
#define toF32(x) ck_tile::type_convert<float>(x)
#define toPF4_(x) ck_tile::type_convert<pk_fp4_t>(x)
#define toSRC(x) ck_tile::type_convert<SRC>(x)
#define toDST_(x) ck_tile::type_convert<DST>(x)
|
||||
|
||||
template <typename Kernel, typename... Args>
|
||||
__global__ void MyKernel(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
template <typename SRC, typename DST, int N>
|
||||
struct SrcPkfp4Dst
|
||||
{
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const SRC* src, DST* dst, e8m0_t scale1, e8m0_t scale2) const
|
||||
{
|
||||
|
||||
using SRCx2_t = ck_tile::ext_vector_t<SRC, 2>;
|
||||
using DSTx2_t = ck_tile::ext_vector_t<DST, 2>;
|
||||
|
||||
ck_tile::static_for<0, N, 2>{}([&](auto i) {
|
||||
const auto input2 = SRCx2_t{src[i], src[i + 1]};
|
||||
|
||||
if(i % 4 == 0)
|
||||
{
|
||||
// ex: fp32_t -> fp4 -> bf16_t
|
||||
dst[i] = toDST(toPF4(src[i], scale1), scale2);
|
||||
// ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t
|
||||
dst[i + 1] = toDST(toPF4_(toPF4(input2, scale1).unpack(number<1>{})), scale2);
|
||||
}
|
||||
else
|
||||
{
|
||||
// ex: fp32x2_t -> pk_fp4_t -> bf16x2_t
|
||||
reinterpret_cast<DSTx2_t*>(dst)[i >> 1] = toDSTx2(toPF4(input2, scale1), scale2);
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SRC, typename DST, bool is_device>
|
||||
CK_TILE_HOST void test_convert()
|
||||
{
|
||||
const auto test_data = std::array{4.f, 6.f, 8.f, 10.f};
|
||||
const auto ref_data = std::array{8.f, 16.f, 16.f, 16.f};
|
||||
const auto scale1 = e8m0_t{8.0f};
|
||||
const auto scale2 = e8m0_t{16.0f};
|
||||
|
||||
static_assert(test_data.size() == ref_data.size());
|
||||
static_assert(test_data.size() % 2 == 0);
|
||||
|
||||
constexpr int N = test_data.size();
|
||||
std::array<SRC, N> in;
|
||||
std::array<DST, N> ref, out;
|
||||
|
||||
// prepare input and ground truth in host
|
||||
for(int i = 0; i < N; ++i)
|
||||
{
|
||||
in[i] = toSRC(test_data[i]);
|
||||
ref[i] = toDST_(ref_data[i]);
|
||||
EXPECT_EQ(test_data[i], toF32(in[i]));
|
||||
EXPECT_EQ(ref_data[i], toF32(ref[i]));
|
||||
}
|
||||
|
||||
using job = SrcPkfp4Dst<SRC, DST, N>;
|
||||
|
||||
if constexpr(is_device)
|
||||
{
|
||||
auto in_d = std::make_unique<ck_tile::DeviceMem>(in.size() * sizeof(SRC));
|
||||
auto out_d = std::make_unique<ck_tile::DeviceMem>(out.size() * sizeof(DST));
|
||||
in_d->ToDevice(in.data());
|
||||
|
||||
MyKernel<job><<<1, 1>>>(reinterpret_cast<const SRC*>(in_d->GetDeviceBuffer()),
|
||||
reinterpret_cast<DST*>(out_d->GetDeviceBuffer()),
|
||||
scale1,
|
||||
scale2);
|
||||
|
||||
out_d->FromDevice(out.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
job{}(in.data(), out.data(), scale1, scale2);
|
||||
}
|
||||
|
||||
for(int i = 0; i < N; ++i)
|
||||
EXPECT_EQ(ref[i], out[i]) << "i:" << i;
|
||||
}
|
||||
Reference in New Issue
Block a user