From b5517fb522f0a8224eaccb68a6473a026e000820 Mon Sep 17 00:00:00 2001 From: Gino Lu Date: Mon, 14 Jul 2025 20:35:06 +0800 Subject: [PATCH] [CK_TILE] Add pk_fp4 data type (#2422) * [draft] Add pk_fp4 and test * Add hw conversion for fp4 * Refine test code and pk_fp4 constructor. * fix test indent * modify according to comment. * fix clang-format * modify according comments. --------- Co-authored-by: asleepzzz [ROCm/composable_kernel commit: 141bf2d54d78f8250fc1ad51ef8f2f54792d2a08] --- include/ck_tile/core.hpp | 2 + include/ck_tile/core/numeric/mxfp_convert.hpp | 213 ++++++++++++ include/ck_tile/core/numeric/numeric.hpp | 178 +++++----- include/ck_tile/core/numeric/pk_fp4.hpp | 324 ++++++++++++++++++ include/ck_tile/core/numeric/type_convert.hpp | 16 + test/ck_tile/data_type/CMakeLists.txt | 1 + test/ck_tile/data_type/test_pk_fp4.cpp | 162 +++++++++ 7 files changed, 806 insertions(+), 90 deletions(-) create mode 100644 include/ck_tile/core/numeric/mxfp_convert.hpp create mode 100644 include/ck_tile/core/numeric/pk_fp4.hpp create mode 100644 test/ck_tile/data_type/test_pk_fp4.cpp diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index ed39719cf4..10dfdd7d28 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -33,8 +33,10 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/integral_constant.hpp" #include "ck_tile/core/numeric/math.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" #include "ck_tile/core/numeric/null_type.hpp" #include "ck_tile/core/numeric/numeric.hpp" +#include "ck_tile/core/numeric/pk_fp4.hpp" #include "ck_tile/core/numeric/pk_int4.hpp" #include "ck_tile/core/numeric/type_convert.hpp" #include "ck_tile/core/numeric/vector_type.hpp" diff --git a/include/ck_tile/core/numeric/mxfp_convert.hpp b/include/ck_tile/core/numeric/mxfp_convert.hpp new file mode 100644 index 0000000000..b2e138e880 --- /dev/null +++ b/include/ck_tile/core/numeric/mxfp_convert.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { +// modify from include/ck/utility/mxfp_utils.hpp + +template +struct numeric_utils : numeric_traits +{ + + using traits = numeric_traits; + using _numeric = numeric; + using raw_type = typename T::raw_type; + + static constexpr int exp_mask = (1 << traits::exp) - 1; + + static constexpr int get_exponent(raw_type x) + { + // TODO: check if repeated calls are optimized. + return (x >> traits::mant) & exp_mask; + } + static constexpr bool is_positive(raw_type x) + { + return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero; + } + static constexpr bool is_subnormal(raw_type x) + { + return get_exponent(x) == _numeric::binary_zero; + } + // TODO: replace double with template arg? + static constexpr double get_mantissa(raw_type x) + { + double mantissa = is_subnormal(x) ? 0.0f : 1.0f; + for(uint32_t i = 0; i < traits::mant; ++i) + { + mantissa += std::ldexp(static_cast(x & 0b1), -(traits::mant - i)); + x >>= 1; + } + return mantissa; + } +}; + +template +CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127) +{ + using utils = numeric_utils; + static constexpr int e8m0_bias = 127; // TODO: make it generic. + float sign = utils::is_positive(data) ? 1.0 : -1.0; + int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias; + float mant = utils::get_mantissa(data); + + return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias); +} + +template +CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value) +{ + using bitwise_type = typename numeric_traits::bitwise_type; + + if(std::abs(value) > float(numeric::max())) + { + float max_value = numeric::max(); + + // cppcheck-suppress redundantAssignment + uint32_t max_bitwise = bit_cast(max_value); + + // cppcheck-suppress redundantAssignment + bitwise_type sign = + bit_cast(value) >> (numeric_traits::exp + numeric_traits::mant); + bitwise_type exp = + ((max_bitwise >> numeric_traits::mant) & numeric_traits::exp_mask) - + (numeric_traits::bias - numeric_traits::bias); + bitwise_type mantissa = + max_bitwise >> (numeric_traits::mant - numeric_traits::mant); + + uint32_t mant_prev = max_bitwise >> (numeric_traits::mant - numeric_traits::mant); + mant_prev &= ((1 << numeric_traits::mant) - 1); + mant_prev--; + + mant_prev <<= (numeric_traits::mant - numeric_traits::mant); + uint32_t prev_bit = + ((max_bitwise >> numeric_traits::mant) << numeric_traits::mant) | + mant_prev; + + float prev_val = bit_cast(prev_bit); + float diff = max_value - prev_val; + + float actual_max = max_value + (diff / 2); + + if(std::abs(value) < actual_max) + { + return sign << ((numeric_traits::exp + numeric_traits::mant)) | + (exp << numeric_traits::mant) | mantissa; + } + else + { + if constexpr(!numeric::has_inf()) + { + + return (1 << (numeric_traits::mant + numeric_traits::exp)) - 1; + } + else + { + exp++; + return sign << ((numeric_traits::exp + numeric_traits::mant)) | + (exp << numeric_traits::mant); + } + } + } + const int mfmt = numeric_traits::mant; + uint32_t x; + x = bit_cast(value); + + uint32_t head, mantissa; + int32_t exponent, bias; + uint32_t sign; + + head = x & numeric_traits::head_mask; + mantissa = x & numeric_traits::mant_mask; + exponent = (head >> numeric_traits::mant) & numeric_traits::exp_mask; + sign = head >> (numeric_traits::mant + numeric_traits::exp); + bias = numeric_traits::bias; + + if(x == 0) + { + return 0b0; + } + + const int mini_bias = numeric_traits::bias; + const int mini_denormal_act_exponent = 1 - mini_bias; + + int act_exponent, out_exponent, exponent_diff; + + bool is_subnorm = false; + + if(exponent == 0) + { + act_exponent = exponent - bias + 1; + exponent_diff = mini_denormal_act_exponent - act_exponent; + is_subnorm = true; + } + else + { + act_exponent = exponent - bias; + if(act_exponent <= mini_denormal_act_exponent) + { + exponent_diff = mini_denormal_act_exponent - act_exponent; + is_subnorm = true; + } + else + { + exponent_diff = 0; + } + mantissa += (1UL << mfmt); + } + + auto shift_amount = (mfmt - numeric_traits::mant + exponent_diff); + shift_amount = (shift_amount >= 64) ? 63 : shift_amount; + bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1)); + + float min_subnorm = float(numeric::epsilon()) * (sign ? -1 : 1); + + if(is_subnorm && std::abs(value) < std::abs(min_subnorm)) + { + // closer to 0 + if(std::abs(value) <= std::abs(min_subnorm - value)) + return sign << (numeric_traits::exp + numeric_traits::mant); + else + return 1 | (sign << (numeric_traits::exp + numeric_traits::mant)); + } + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << mfmt); + out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1); + + uint32_t drop_mask = (1UL << (mfmt - numeric_traits::mant)) - 1; + bool odd = mantissa & (1UL << (mfmt - numeric_traits::mant)); + mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask; + + if(out_exponent == 0) + { + if((1UL << mfmt) & mantissa) + { + out_exponent = 1; + } + } + else + { + if((1UL << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + } + } + + mantissa >>= (mfmt - numeric_traits::mant); + + if(out_exponent == 0 && mantissa == 0) + { + return sign << (numeric_traits::exp + numeric_traits::mant); + } + + mantissa &= (1UL << numeric_traits::mant) - 1; + return (sign << (numeric_traits::exp + numeric_traits::mant)) | + (out_exponent << numeric_traits::mant) | mantissa; +} + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/numeric.hpp b/include/ck_tile/core/numeric/numeric.hpp index f125fbf2ce..6b61e3f99c 100644 --- a/include/ck_tile/core/numeric/numeric.hpp +++ b/include/ck_tile/core/numeric/numeric.hpp @@ -103,94 +103,92 @@ struct numeric_traits } // namespace ck_tile -#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \ - attr_ bool operator==(const type_& x, const type_& y) \ - { \ - return static_cast(x) == static_cast(y); \ - } \ - attr_ bool operator!=(const type_& x, const type_& y) \ - { \ - return static_cast(x) != static_cast(y); \ - } \ - attr_ bool operator<(const type_& x, const type_& y) \ - { \ - return static_cast(x) < static_cast(y); \ - } \ - attr_ bool operator<=(const type_& x, const type_& y) \ - { \ - return static_cast(x) <= static_cast(y); \ - } \ - attr_ bool operator>(const type_& x, const type_& y) \ - { \ - return static_cast(x) > static_cast(y); \ - } \ - attr_ bool operator>=(const type_& x, const type_& y) \ - { \ - return static_cast(x) >= static_cast(y); \ - } \ - attr_ type_ operator+(const type_& x, const type_& y) \ - { \ - return type_(static_cast(x) + static_cast(y)); \ - } \ - attr_ type_ operator-(const type_& x) \ - { \ - constexpr uint32_t bits = sizeof(type_) * 8; \ - constexpr uint32_t mask = 1 << (bits - 1); \ - type_ y = x; \ - y.data ^= static_cast(mask); \ - return y; \ - } \ - attr_ type_ operator-(const type_& x, const type_& y) \ - { \ - return type_(static_cast(x) - static_cast(y)); \ - } \ - attr_ type_ operator*(const type_& x, const type_& y) \ - { \ - return type_(static_cast(x) * static_cast(y)); \ - } \ - attr_ type_ operator/(const type_& x, const type_& y) \ - { \ - return type_(static_cast(x) / static_cast(y)); \ - } \ - attr_ type_& operator+=(type_& x, const type_& y) \ - { \ - x = type_(static_cast(x) + static_cast(y)); \ - return x; \ - } \ - attr_ type_& operator-=(type_& x, const type_& y) \ - { \ - x = type_(static_cast(x) - static_cast(y)); \ - return x; \ - } \ - attr_ type_& operator*=(type_& x, const type_& y) \ - { \ - x = type_(static_cast(x) * static_cast(y)); \ - return x; \ - } \ - attr_ type_& operator/=(type_& x, const type_& y) \ - { \ - x = type_(static_cast(x) / static_cast(y)); \ - return x; \ - } \ - attr_ type_& operator++(type_& x) \ - { \ - x = type_(static_cast(x) + 1.f); \ - return x; \ - } \ - attr_ type_& operator--(type_& x) \ - { \ - x = type_(static_cast(x) - 1.f); \ - return x; \ - } \ - attr_ type_ operator++(type_& x, int) \ - { \ - type_ y(x); \ - x = type_(static_cast(x) + 1.f); \ - return y; \ - } \ - attr_ type_ operator--(type_& x, int) \ - { \ - type_ y(x); \ - x = type_(static_cast(x) - 1.f); \ - return y; \ +#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \ + attr_ bool operator==(const type_& x, const type_& y) \ + { \ + return std::abs(static_cast(x) - static_cast(y)) < \ + static_cast(numeric::epsilon()); \ + } \ + attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \ + attr_ bool operator<(const type_& x, const type_& y) \ + { \ + return static_cast(x) < static_cast(y); \ + } \ + attr_ bool operator<=(const type_& x, const type_& y) \ + { \ + return static_cast(x) <= static_cast(y); \ + } \ + attr_ bool operator>(const type_& x, const type_& y) \ + { \ + return static_cast(x) > static_cast(y); \ + } \ + attr_ bool operator>=(const type_& x, const type_& y) \ + { \ + return static_cast(x) >= static_cast(y); \ + } \ + attr_ type_ operator+(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) + static_cast(y)); \ + } \ + attr_ type_ operator-(const type_& x) \ + { \ + constexpr uint32_t bits = sizeof(type_) * 8; \ + constexpr uint32_t mask = 1 << (bits - 1); \ + type_ y = x; \ + y.data ^= static_cast(mask); \ + return y; \ + } \ + attr_ type_ operator-(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) - static_cast(y)); \ + } \ + attr_ type_ operator*(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) * static_cast(y)); \ + } \ + attr_ type_ operator/(const type_& x, const type_& y) \ + { \ + return type_(static_cast(x) / static_cast(y)); \ + } \ + attr_ type_& operator+=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) + static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator-=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) - static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator*=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) * static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator/=(type_& x, const type_& y) \ + { \ + x = type_(static_cast(x) / static_cast(y)); \ + return x; \ + } \ + attr_ type_& operator++(type_& x) \ + { \ + x = type_(static_cast(x) + 1.f); \ + return x; \ + } \ + attr_ type_& operator--(type_& x) \ + { \ + x = type_(static_cast(x) - 1.f); \ + return x; \ + } \ + attr_ type_ operator++(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) + 1.f); \ + return y; \ + } \ + attr_ type_ operator--(type_& x, int) \ + { \ + type_ y(x); \ + x = type_(static_cast(x) - 1.f); \ + return y; \ } diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp new file mode 100644 index 0000000000..b7dca9dd0a --- /dev/null +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -0,0 +1,324 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck_tile/core/config.hpp" +#include "ck_tile/core/numeric/half.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" + +#if defined(__gfx950__) +#define CK_TILE_FP4_CVT_DEVICE 1 +#else +#define CK_TILE_FP4_CVT_DEVICE 0 +#endif + +#define TEST_convert_with_table 0 + +namespace ck_tile { + +using fp32_t = float; +using fp32x2_t = float __attribute__((ext_vector_type(2))); +using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2))); + +CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float); + +// TODO: Add stochastic method +struct pk_float4_e2m1_t +{ + static constexpr int exponent = 2; + static constexpr int mantissa = 1; + static constexpr int bias = 1; + // TODO: Can we merge raw_type and type? + using raw_type = uint8_t; + using type = raw_type; + raw_type data; + + CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {} + template >> + CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast(init)} + { + } + CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)} + { + } + CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr operator float() const; + CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const; + CK_TILE_HOST_DEVICE constexpr operator fp16_t() const; + CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const; + CK_TILE_HOST_DEVICE constexpr operator bf16_t() const; + CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const; + + template + CK_TILE_HOST_DEVICE raw_type unpack(number) const; + CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } + +#if TEST_convert_with_table + static constexpr float e2m1_to_fp32_table[16] = { + 0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6}; + static constexpr fp16_t e2m1_to_fp16_table[16] = { + bit_cast(static_cast(0x0000)), // 0 + bit_cast(static_cast(0x3800)), // 0.5 + bit_cast(static_cast(0x3C00)), // 1 + bit_cast(static_cast(0x3E00)), // 1.5 + bit_cast(static_cast(0x4000)), // 2 + bit_cast(static_cast(0x4200)), // 3 + bit_cast(static_cast(0x4400)), // 4 + bit_cast(static_cast(0x4600)), // 6 + bit_cast(static_cast(0x8000)), // -0 + bit_cast(static_cast(0xB800)), // -0.5 + bit_cast(static_cast(0xBC00)), // -1 + bit_cast(static_cast(0xBE00)), // -1.5 + bit_cast(static_cast(0xC000)), // -2 + bit_cast(static_cast(0xC200)), // -3 + bit_cast(static_cast(0xC400)), // -4 + bit_cast(static_cast(0xC600)) // -6 + }; +#endif +}; + +using pk_fp4_t = pk_float4_e2m1_t; +using pk_fp4_raw_t = typename pk_fp4_t::raw_type; + +template <> +struct numeric_traits +{ + using bitwise_type = pk_fp4_raw_t; + + static constexpr int exp = 2; + static constexpr int mant = 1; + static constexpr int bias = 1; + static constexpr int PackedSize = 2; +}; + +// limits +template +struct numeric; + +template <> +struct numeric +{ + static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1 + static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6 + static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6 + static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5 + static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5 + static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0 + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; } + CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; } + + CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } + // N/A + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); } + // N/A + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); } + // N/A + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); } +}; + +template +CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number) const +{ + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 1) + return (data >> 4); + else + return data & 0b00001111; +} +CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t) +// TODO: consider replace this macro to improve performance + +#if CK_TILE_FP4_CVT_DEVICE +namespace impl { + +template +CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f) +{ + // TODO: check the order + if constexpr(std::is_same_v) + return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0]; + else if constexpr(std::is_same_v) + return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0); + else if constexpr(std::is_same_v) + return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0]; + else if constexpr(std::is_same_v) + return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0); + else if constexpr(std::is_same_v) + return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0]; + else if constexpr(std::is_same_v) + return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0); + else + static_assert(std::false_type::value, "Unsupported type."); + return T{}; +} +template +CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f) +{ + // TODO: check the order + union + { + uint32_t u32; + pk_fp4_raw_t pf4[4]; + } cvt{0}; + if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0); + else if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0); + else if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0); + else if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0); + else if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0); + else if constexpr(std::is_same_v) + cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0); + else + static_assert(std::false_type::value, "Unsupported type."); + return cvt.pf4[0]; +} + +} // namespace impl +#endif + +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return bf16_t{type_convert(convert_to_float(unpack(number<0>{})))}; +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), + type_convert(convert_to_float(unpack(number<1>{})))}; +#endif +} + +// TODO: make float_to_e2m1 generic so that we can convert from directrly. +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return convert_to_type(x); +#endif +} +CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); } +CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); } +CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); } +CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); } +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return float_to_e2m1(type_convert(x)); +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return float_to_e2m1(type_convert(x)); +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), + float_to_e2m1(type_convert(x[1]))); +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return pk_fp4_t::pack(float_to_e2m1(type_convert(x[0])), + float_to_e2m1(type_convert(x[1]))); +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x) +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x); +#else + return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1])); +#endif +} + +#if TEST_convert_with_table == 0 +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return convert_to_float(unpack(number<0>{})); +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return fp32x2_t{convert_to_float(unpack(number<0>{})), + convert_to_float(unpack(number<1>{}))}; +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return fp16_t{type_convert(convert_to_float(unpack(number<0>{})))}; +#endif +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +{ +#if CK_TILE_FP4_CVT_DEVICE + return impl::_from_f4(data); +#else + return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}))), + type_convert(convert_to_float(unpack(number<1>{})))}; +#endif +} +#else +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const +{ + return e2m1_to_fp32_table[data & 0xf]; +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const +{ + return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]}; +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const +{ + return e2m1_to_fp16_table[data & 0xf]; +} +CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const +{ + return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]}; +} +#endif + +} // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 4011e08ce4..94d6e3cd34 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -11,6 +11,7 @@ #include "ck_tile/core/numeric/bfloat16.hpp" #include "ck_tile/core/numeric/float8.hpp" #include "ck_tile/core/numeric/int8.hpp" +#include "ck_tile/core/numeric/mxfp_convert.hpp" namespace ck_tile { @@ -64,6 +65,21 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(float, float, int8_t, int8) CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) +} // namespace ck_tile + +#include "ck_tile/core/numeric/pk_fp4.hpp" + +namespace ck_tile { + +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2) +CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4) +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2) +CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4) +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2) +CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4) +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float) +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16) +CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16) #undef CK_TILE_TYPE_CONVERT #endif diff --git a/test/ck_tile/data_type/CMakeLists.txt b/test/ck_tile/data_type/CMakeLists.txt index e489f306f7..655a0cef9c 100644 --- a/test/ck_tile/data_type/CMakeLists.txt +++ b/test/ck_tile/data_type/CMakeLists.txt @@ -1,4 +1,5 @@ # Currently ck_tile is only built on gfx9 if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_ck_tile_pk_int4 test_pk_int4.cpp) + add_gtest_executable(test_ck_tile_pk_fp4 test_pk_fp4.cpp) endif() diff --git a/test/ck_tile/data_type/test_pk_fp4.cpp b/test/ck_tile/data_type/test_pk_fp4.cpp new file mode 100644 index 0000000000..15f027e95d --- /dev/null +++ b/test/ck_tile/data_type/test_pk_fp4.cpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" + +using ck_tile::bf16_t; +using ck_tile::bf16x2_t; +using ck_tile::fp16_t; +using ck_tile::fp16x2_t; +using ck_tile::fp32_t; +using ck_tile::fp32x2_t; +using ck_tile::number; +using ck_tile::pk_fp4_t; + +template +CK_TILE_HOST void test_convert(); + +TEST(PackedFp4, NumericLimits) +{ + EXPECT_EQ(ck_tile::numeric::has_inf(), false); + EXPECT_EQ(ck_tile::numeric::zero(), pk_fp4_t{0b00000000}); + EXPECT_EQ(ck_tile::numeric::min(), pk_fp4_t{0b00100010}); + EXPECT_EQ(ck_tile::numeric::max(), pk_fp4_t{0b01110111}); + EXPECT_EQ(ck_tile::numeric::lowest(), pk_fp4_t{0b11111111}); + EXPECT_EQ(ck_tile::numeric::epsilon(), pk_fp4_t{0b00010001}); + EXPECT_EQ(ck_tile::numeric::round_error(), pk_fp4_t{0b00010001}); +} +TEST(PackedFp4, ConvertBasic) +{ + EXPECT_EQ(ck_tile::convert_to_type(0.0f), pk_fp4_t{0b00000000}.get()); + EXPECT_EQ(ck_tile::convert_to_type(-0.0f), pk_fp4_t{0b00001000}.get()); + EXPECT_EQ(ck_tile::convert_to_type(-1.0f), pk_fp4_t{0b00001010}.get()); + EXPECT_EQ(ck_tile::type_convert(0.0f), pk_fp4_t{0b00000000}); + EXPECT_EQ(ck_tile::type_convert(-0.0f), pk_fp4_t{0b00001000}); + EXPECT_EQ(ck_tile::type_convert(-1.0f), pk_fp4_t{0b00001010}); + EXPECT_EQ(pk_fp4_t(0.0f), pk_fp4_t{0b00000000}); + EXPECT_EQ(pk_fp4_t(-0.0f), pk_fp4_t{0b00001000}); + EXPECT_EQ(pk_fp4_t(-1.0f), pk_fp4_t{0b00001010}); + EXPECT_EQ(pk_fp4_t{0.0f}, pk_fp4_t{0b00000000}); + EXPECT_EQ(pk_fp4_t{-0.0f}, pk_fp4_t{0b00001000}); + EXPECT_EQ(pk_fp4_t{-1.0f}, pk_fp4_t{0b00001010}); +} +TEST(PackedFp4, NumericBasic) +{ + auto f1 = pk_fp4_t{1.5f}; + auto f2 = pk_fp4_t{3.0f}; + auto ref = pk_fp4_t{-1.5f}; + EXPECT_EQ(f1 - f2, ref); +} +TEST(PackedFp4, ConvertDevice) +{ + constexpr bool is_device = true; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} +TEST(PackedFp4, ConvertHost) +{ + constexpr bool is_device = false; + test_convert(); // fp32 -> fp4 -> fp32 + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); + test_convert(); +} + +#define toF32(x) ck_tile::type_convert(x) +#define toPF4(x) ck_tile::type_convert(x) +#define toSRC(x) ck_tile::type_convert(x) +#define toDST(x) ck_tile::type_convert(x) +#define toDSTx2(x) ck_tile::type_convert(x) + +template +__global__ void MyKernel(Args... args) +{ + Kernel{}(args...); +} +template +struct SrcPkfp4Dst +{ + CK_TILE_HOST_DEVICE void operator()(const SRC* src, DST* dst) const + { + + using SRCx2_t = ck_tile::ext_vector_t; + using DSTx2_t = ck_tile::ext_vector_t; + + ck_tile::static_for<0, N, 2>{}([&](auto i) { + const auto input2 = SRCx2_t{src[i], src[i + 1]}; + + if(i % 4 == 0) + { + // ex: fp32_t -> fp4 -> bf16_t + dst[i] = toDST(toPF4(src[i])); + // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t + dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{}))); + } + else + { + // ex: fp32x2_t -> pk_fp4_t -> bf16x2_t + reinterpret_cast(dst)[i >> 1] = toDSTx2(toPF4(input2)); + } + }); + } +}; + +template +CK_TILE_HOST void test_convert() +{ + const auto test_data = std::array{0.f, 0.25f, 0.5f, 0.75f, 1.f, 1.25f, 1.5f, 1.75f, + -0.f, -0.25f, -0.5f, -0.75f, -1.f, -1.25f, -1.5f, -1.75f, + 2.f, 2.5f, 3.f, 3.5f, 4.f, 5.f, 5.0625f, 6.f}; + const auto ref_data = + std::array{0.f, 0.f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 2.f, -0.f, -0.f, -0.5f, -1.f, + -1.f, -1.f, -1.5f, -2.f, 2.f, 2.f, 3.f, 4.f, 4.f, 4.f, 6.f, 6.f}; + + static_assert(test_data.size() == ref_data.size()); + static_assert(test_data.size() % 2 == 0); + + constexpr int N = test_data.size(); + std::array in; + std::array ref, out; + + // prepare input and ground truth in host + for(int i = 0; i < N; ++i) + { + in[i] = toSRC(test_data[i]); + ref[i] = toDST(ref_data[i]); + EXPECT_EQ(test_data[i], toF32(in[i])); + EXPECT_EQ(ref_data[i], toF32(ref[i])); + } + + using job = SrcPkfp4Dst; + + if constexpr(is_device) + { + auto in_d = std::make_unique(in.size() * sizeof(SRC)); + auto out_d = std::make_unique(out.size() * sizeof(DST)); + in_d->ToDevice(in.data()); + + MyKernel<<<1, 1>>>(reinterpret_cast(in_d->GetDeviceBuffer()), + reinterpret_cast(out_d->GetDeviceBuffer())); + + out_d->FromDevice(out.data()); + } + else + { + job{}(in.data(), out.data()); + } + + for(int i = 0; i < N; ++i) + EXPECT_EQ(ref[i], out[i]) << "i:" << i; +}