From ed2293a87b3b1869526bacd9ffe207dda18ed679 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Thu, 7 Aug 2025 21:37:28 +0800 Subject: [PATCH] Add e8m0 scaled convert into CK_TILE (#2617) * first commit * remove redundent code * modify according to comments. * fix type_convert error with scaled_type_convert [ROCm/composable_kernel commit: 5d6d236b255b4ef9c8f38e1bd35975acda0af19a] --- include/ck_tile/core.hpp | 1 + include/ck_tile/core/numeric/e8m0.hpp | 102 +++++++++++ include/ck_tile/core/numeric/mxfp_convert.hpp | 27 +-- include/ck_tile/core/numeric/pk_fp4.hpp | 163 +++++++++++------- include/ck_tile/core/numeric/type_convert.hpp | 41 +++-- include/ck_tile/host/host_tensor.hpp | 8 +- test/ck_tile/data_type/CMakeLists.txt | 1 + test/ck_tile/data_type/test_mx_scale.cpp | 162 +++++++++++++++++ 8 files changed, 419 insertions(+), 86 deletions(-) create mode 100644 include/ck_tile/core/numeric/e8m0.hpp create mode 100644 test/ck_tile/data_type/test_mx_scale.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index c8945f03e9..9f3c996873 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -27,6 +27,7 @@ #include "ck_tile/core/container/thread_buffer.hpp" #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/numeric/bfloat16.hpp" +#include "ck_tile/core/numeric/e8m0.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/half.hpp" #include "ck_tile/core/numeric/int8.hpp" diff --git a/include/ck_tile/core/numeric/e8m0.hpp b/include/ck_tile/core/numeric/e8m0.hpp new file mode 100644 index 0000000000..ea94880f27 --- /dev/null +++ b/include/ck_tile/core/numeric/e8m0.hpp @@ -0,0 +1,102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +namespace ck_tile { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. 
+ * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ + +struct e8m0_bexp_t +{ + using raw_type = uint8_t; + using type = raw_type; + + raw_type data; + + CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t() : data{type{0b11111111}} {} + CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(type init) : data{init} {} + CK_TILE_HOST_DEVICE explicit constexpr e8m0_bexp_t(float scale) + : e8m0_bexp_t(static_cast(numeric_utils::get_exponent(scale))) + { + } + CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr operator float() const; + + constexpr bool operator==(const e8m0_bexp_t& other) const { return data == other.data; } + + constexpr bool operator!=(const e8m0_bexp_t& other) const { return data != other.data; } +}; + +using e8m0_t = e8m0_bexp_t; +using e8m0_raw_t = typename e8m0_t::raw_type; + +template <> +struct numeric_traits +{ + using bitwise_type = e8m0_raw_t; + + static constexpr int exp = 8; + static constexpr int mant = 0; + static constexpr int bias = 127; + static constexpr int PackedSize = 1; +}; + +// limits +template +struct numeric; + +template <> +struct numeric +{ + static constexpr e8m0_raw_t binary_min = 0b00000000; // 2^-127 + static constexpr e8m0_raw_t binary_max = 0b11111110; // 2^127 + static constexpr e8m0_raw_t binary_nan = 0b11111111; + CK_TILE_HOST_DEVICE static constexpr e8m0_t min() { return e8m0_t{binary_min}; } + CK_TILE_HOST_DEVICE static constexpr e8m0_t max() { return e8m0_t{binary_max}; } + CK_TILE_HOST_DEVICE static constexpr 
e8m0_t quiet_NaN() { return e8m0_t{binary_nan}; } + CK_TILE_HOST_DEVICE static constexpr e8m0_t signaling_NaN() { return e8m0_t{binary_nan}; } + CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } + + CK_TILE_HOST_DEVICE static constexpr e8m0_t epsilon() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t round_error() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t zero() { return signaling_NaN(); } + CK_TILE_HOST_DEVICE static constexpr e8m0_t infinity() { return signaling_NaN(); } +}; + +CK_TILE_HOST_DEVICE constexpr e8m0_bexp_t::operator float() const +{ + using traits = numeric_traits; + if(data == numeric::binary_nan) + { + return traits::NaN; + } + else if(data == 0) + { + return std::numeric_limits::min(); + } + else + { + return bit_cast(static_cast(data) << traits::mant); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/mxfp_convert.hpp b/include/ck_tile/core/numeric/mxfp_convert.hpp index b2e138e880..9b378933d0 100644 --- a/include/ck_tile/core/numeric/mxfp_convert.hpp +++ b/include/ck_tile/core/numeric/mxfp_convert.hpp @@ -12,15 +12,19 @@ struct numeric_utils : numeric_traits using traits = numeric_traits; using _numeric = numeric; - using raw_type = typename T::raw_type; + using raw_type = typename traits::bitwise_type; static constexpr int exp_mask = (1 << traits::exp) - 1; - static constexpr int get_exponent(raw_type x) + static constexpr raw_type get_exponent(raw_type x) { // TODO: check if repeated calls are optimized. return (x >> traits::mant) & exp_mask; } + static constexpr raw_type get_exponent(const T& x) + { + return get_exponent(bit_cast(x)); + } static constexpr bool is_positive(raw_type x) { return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero; @@ -33,7 +37,7 @@ struct numeric_utils : numeric_traits static constexpr double get_mantissa(raw_type x) { double mantissa = is_subnormal(x) ? 
0.0f : 1.0f; - for(uint32_t i = 0; i < traits::mant; ++i) + for(raw_type i = 0; i < traits::mant; ++i) { mantissa += std::ldexp(static_cast(x & 0b1), -(traits::mant - i)); x >>= 1; @@ -43,22 +47,23 @@ struct numeric_utils : numeric_traits }; template -CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127) +CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, float scale = 1.f) { - using utils = numeric_utils; - static constexpr int e8m0_bias = 127; // TODO: make it generic. - float sign = utils::is_positive(data) ? 1.0 : -1.0; - int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias; - float mant = utils::get_mantissa(data); + using utils = numeric_utils; + float sign = utils::is_positive(data) ? 1.0 : -1.0; + int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias; + float mant = utils::get_mantissa(data); - return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias); + return std::ldexp(sign * mant * scale, exp); } template -CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value) +CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value, float scale = 1.f) { using bitwise_type = typename numeric_traits::bitwise_type; + value /= scale; + if(std::abs(value) > float(numeric::max())) { float max_value = numeric::max(); diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index 0dee750b69..a345cd1b75 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -23,14 +23,11 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float); +CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f); // TODO: Add stochastic method struct 
pk_float4_e2m1_t { - static constexpr int exponent = 2; - static constexpr int mantissa = 1; - static constexpr int bias = 1; // TODO: Can we merge raw_type and type? using raw_type = uint8_t; using type = raw_type; @@ -41,18 +38,27 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast(init)} { } - CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)} + CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f) + : data{float_to_e2m1(init, scale)} { } CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } - CK_TILE_HOST_DEVICE constexpr operator float() const; - CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const; - CK_TILE_HOST_DEVICE constexpr operator fp16_t() const; - CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const; - CK_TILE_HOST_DEVICE constexpr operator bf16_t() const; - CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const; + + CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp16_t to_fp16(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr fp16x2_t to_fp16x2(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr bf16_t to_bf16(float scale = 1.f) const; + CK_TILE_HOST_DEVICE constexpr bf16x2_t to_bf16x2(float scale = 1.f) const; + + CK_TILE_HOST_DEVICE constexpr operator float() const { return to_float(); } + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const { return to_fp32x2(); } + CK_TILE_HOST_DEVICE constexpr operator fp16_t() const { return to_fp16(); } + CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const { return to_fp16x2(); } + CK_TILE_HOST_DEVICE constexpr operator bf16_t() const { return to_bf16(); } + CK_TILE_HOST_DEVICE 
constexpr operator bf16x2_t() const { return to_bf16x2(); } template CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; @@ -191,131 +197,160 @@ CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f) } // namespace impl #endif -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const +CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return bf16_t{type_convert(convert_to_float(unpack(number<0>{})))}; + return bf16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const + +CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), - type_convert(convert_to_float(unpack(number<1>{})))}; + return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), + type_convert(convert_to_float(unpack(number<1>{}), scale))}; #endif } // TODO: make float_to_e2m1 generic so that we can convert from directrly. 
-CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return convert_to_type(x); + return convert_to_type(x, scale); #endif } -CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); } -CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); } -CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale) +{ + return float_to_e2m1(x, scale); +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x)); + return float_to_e2m1(type_convert(x), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x)); + return float_to_e2m1(type_convert(x), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), - float_to_e2m1(type_convert(x[1]))); + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0]), scale), + float_to_e2m1(type_convert(x[1]), 
scale)); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), - float_to_e2m1(type_convert(x[1]))); + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0]), scale), + float_to_e2m1(type_convert(x[1]), scale)); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x) +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE - return impl::_to_f4(x); + return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1])); + return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); #endif } +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale) +{ + return x.to_fp32x2(scale); +} +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale) +{ + return x.to_fp16x2(scale); +} +CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale) +{ + return x.to_bf16x2(scale); +} +CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale) +{ + return x.to_float(scale); +} +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale) +{ + return x.to_fp16(scale); +} +CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale) +{ + return x.to_bf16(scale); +} + #if TEST_convert_with_table == 0 -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return convert_to_float(unpack(number<0>{})); + return 
convert_to_float(unpack(number<0>{}), scale); #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp32x2_t{convert_to_float(unpack(number<0>{})), - convert_to_float(unpack(number<1>{}))}; + return fp32x2_t{convert_to_float(unpack(number<0>{}), scale), + convert_to_float(unpack(number<1>{}), scale)}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const + +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp16_t{type_convert(convert_to_float(unpack(number<0>{})))}; + return fp16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; #endif } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { #if CK_TILE_FP4_CVT_DEVICE - return impl::_from_f4(data); + return impl::_from_f4(data, scale); #else - return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), - type_convert(convert_to_float(unpack(number<1>{})))}; + return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), + type_convert(convert_to_float(unpack(number<1>{}), scale))}; #endif } #else -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { - return e2m1_to_fp32_table[data & 0xf]; + return e2m1_to_fp32_table[unpack(number<0>{})] * scale; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]}; + return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, 
e2m1_to_fp32_table[unpack(number<1>{})] * scale}; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const +CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { - return e2m1_to_fp16_table[data & 0xf]; + return type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale; } -CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { - return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]}; + return fp16x2_t{ + type_convert(type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)}; } #endif diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 94d6e3cd34..1455fce0ea 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -64,6 +64,7 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(float, float, int8_t, int8) CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) +#undef CK_TILE_TYPE_CONVERT } // namespace ck_tile @@ -71,16 +72,36 @@ CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) namespace ck_tile { -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2) -CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2) -CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2) -CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16) -CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16) -#undef CK_TILE_TYPE_CONVERT +template +CK_TILE_HOST_DEVICE constexpr Y scaled_type_convert(X x, float scale); + +#define CK_TILE_SCALED_TYPE_CONVERT(dtype_, dname_, stype_, sname_) \ + template <> \ + 
CK_TILE_HOST_DEVICE constexpr dtype_ scaled_type_convert(stype_ x, \ + float scale) \ + { \ + return sname_##_to_##dname_(x, scale); \ + } \ + template <> \ + CK_TILE_HOST_DEVICE constexpr dtype_ type_convert(stype_ x) \ + { \ + return sname_##_to_##dname_(x, 1.f); \ + } + +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2) +CK_TILE_SCALED_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2) +CK_TILE_SCALED_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2) +CK_TILE_SCALED_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float) +CK_TILE_SCALED_TYPE_CONVERT(float, float, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16) +CK_TILE_SCALED_TYPE_CONVERT(bf16_t, bf16, pk_fp4_t, pk_fp4) +CK_TILE_SCALED_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16) +CK_TILE_SCALED_TYPE_CONVERT(fp16_t, fp16, pk_fp4_t, pk_fp4) +#undef CK_TILE_SCALED_TYPE_CONVERT + #endif } // namespace ck_tile diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index c3f1b7d221..b7329fcac7 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -409,7 +409,13 @@ struct HostTensor } // void SetZero() { ck_tile::ranges::fill(mData, 0); } - void SetZero() { std::fill(mData.begin(), mData.end(), 0); } + void SetZero() + { + if constexpr(std::is_same_v) + std::fill(mData.begin(), mData.end(), e8m0_t{1.f}); + else + std::fill(mData.begin(), mData.end(), 0); + } template void ForEach_impl(F&& f, std::vector& idx, size_t rank) diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index a9461dca9c..384fd3c1c4 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -3,6 +3,7 @@ if(GPU_TARGETS MATCHES "gfx9") endif() if(GPU_TARGETS MATCHES "gfx95") 
add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp) + add_gtest_executable(test_ck_tile_mx_scale test_mx_scale.cpp) endif() if(CK_USE_OCP_FP8 OR CK_USE_FNUZ_FP8) diff --git a/test/ck_tile/data_type/test_mx_scale.cpp b/test/ck_tile/data_type/test_mx_scale.cpp new file mode 100644 index 0000000000..7a024d238f --- /dev/null +++ b/test/ck_tile/data_type/test_mx_scale.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using ck_tile::bf16_t; +using ck_tile::bf16x2_t; +using ck_tile::fp16_t; +using ck_tile::fp16x2_t; +using ck_tile::fp32_t; +using ck_tile::fp32x2_t; +using ck_tile::number; +using ck_tile::pk_fp4_t; + +template +CK_TILE_HOST void test_convert(); + +using ck_tile::e8m0_raw_t; +using ck_tile::e8m0_t; + +TEST(OCP_Scale, NumericLimits) +{ + EXPECT_EQ(ck_tile::numeric::has_inf(), false); + EXPECT_EQ(ck_tile::numeric::zero(), ck_tile::numeric::signaling_NaN()); + EXPECT_EQ(ck_tile::numeric::min(), e8m0_t{e8m0_raw_t{0b00000000}}); + EXPECT_EQ(ck_tile::numeric::max(), e8m0_t{e8m0_raw_t{0b11111110}}); +} +TEST(OCP_Scale, NumericBasic) +{ + auto scale_1 = e8m0_t{1.0f}; + auto scale_2 = e8m0_t{e8m0_raw_t{ck_tile::numeric_traits::bias}}; // 2^0 + EXPECT_EQ(scale_1, scale_2); + + auto scale_3 = e8m0_t{8.0f}; + auto scale_4 = e8m0_t{e8m0_raw_t{3 + ck_tile::numeric_traits::bias}}; // 2^3 + EXPECT_EQ(scale_3, scale_4); +} + +TEST(OCP_Scale, ScaledConvertDevice) +{ + constexpr bool is_device = true; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} +TEST(OCP_Scale, ScaledConvertHost) +{ + constexpr bool is_device = false; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} +TEST(OCP_Scale, 
tensorInit) +{ + using scale_t = e8m0_t; + ck_tile::HostTensor scales({10, 10}); + ck_tile::FillUniformDistribution{1.f, 1.f}(scales); + scales.SetZero(); +} + +#define toPF4(x, y) ck_tile::scaled_type_convert(x, y) +#define toDST(x, y) ck_tile::scaled_type_convert(x, y) +#define toDSTx2(x, y) ck_tile::scaled_type_convert(x, y) + +#define toF32(x) ck_tile::type_convert(x) +#define toPF4_(x) ck_tile::type_convert(x) +#define toSRC(x) ck_tile::type_convert(x) +#define toDST_(x) ck_tile::type_convert(x) + +template +__global__ void MyKernel(Args... args) +{ + Kernel{}(args...); +} +template +struct SrcPkfp4Dst +{ + CK_TILE_HOST_DEVICE void + operator()(const SRC* src, DST* dst, e8m0_t scale1, e8m0_t scale2) const + { + + using SRCx2_t = ck_tile::ext_vector_t; + using DSTx2_t = ck_tile::ext_vector_t; + + ck_tile::static_for<0, N, 2>{}([&](auto i) { + const auto input2 = SRCx2_t{src[i], src[i + 1]}; + + if(i % 4 == 0) + { + // ex: fp32_t -> fp4 -> bf16_t + dst[i] = toDST(toPF4(src[i], scale1), scale2); + // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t + dst[i + 1] = toDST(toPF4_(toPF4(input2, scale1).unpack(number<1>{})), scale2); + } + else + { + // ex: fp32x2_t -> pk_fp4_t -> bf16x2_t + reinterpret_cast(dst)[i >> 1] = toDSTx2(toPF4(input2, scale1), scale2); + } + }); + } +}; + +template +CK_TILE_HOST void test_convert() +{ + const auto test_data = std::array{4.f, 6.f, 8.f, 10.f}; + const auto ref_data = std::array{8.f, 16.f, 16.f, 16.f}; + const auto scale1 = e8m0_t{8.0f}; + const auto scale2 = e8m0_t{16.0f}; + + static_assert(test_data.size() == ref_data.size()); + static_assert(test_data.size() % 2 == 0); + + constexpr int N = test_data.size(); + std::array in; + std::array ref, out; + + // prepare input and ground truth in host + for(int i = 0; i < N; ++i) + { + in[i] = toSRC(test_data[i]); + ref[i] = toDST_(ref_data[i]); + EXPECT_EQ(test_data[i], toF32(in[i])); + EXPECT_EQ(ref_data[i], toF32(ref[i])); + } + + using job = SrcPkfp4Dst; + + if constexpr(is_device) + 
{ + auto in_d = std::make_unique(in.size() * sizeof(SRC)); + auto out_d = std::make_unique(out.size() * sizeof(DST)); + in_d->ToDevice(in.data()); + + MyKernel<<<1, 1>>>(reinterpret_cast(in_d->GetDeviceBuffer()), + reinterpret_cast(out_d->GetDeviceBuffer()), + scale1, + scale2); + + out_d->FromDevice(out.data()); + } + else + { + job{}(in.data(), out.data(), scale1, scale2); + } + + for(int i = 0; i < N; ++i) + EXPECT_EQ(ref[i], out[i]) << "i:" << i; +}