diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index f25b98f5a0..8b78990d08 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -23,7 +23,8 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f); +struct pk_float4_e2m1_t; +CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f); // TODO: Add stochastic method struct pk_float4_e2m1_t @@ -31,7 +32,7 @@ struct pk_float4_e2m1_t // TODO: Can we merge raw_type and type? using raw_type = uint8_t; using type = raw_type; - raw_type data; + type data; CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {} template >> @@ -39,12 +40,12 @@ struct pk_float4_e2m1_t { } CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f) - : data{float_to_e2m1(init, scale)} + : data{float_to_pk_fp4(init, scale)} { } CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr type get() const { return data; } CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const; @@ -61,8 +62,19 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } template - CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; - CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1) + CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const + { + return 
_unpack(number{}); + } + CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0, + const pk_float4_e2m1_t& x1) + { + return _pack(x0.get(), x1.get()); + } + + template + CK_TILE_HOST_DEVICE constexpr type _unpack(number) const; + CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1) { return (x1 << 4) | (x0 & 0b00001111); } @@ -92,7 +104,7 @@ struct pk_float4_e2m1_t }; using pk_fp4_t = pk_float4_e2m1_t; -using pk_fp4_raw_t = typename pk_fp4_t::raw_type; +using pk_fp4_raw_t = typename pk_fp4_t::type; template <> struct numeric_traits @@ -124,7 +136,7 @@ struct numeric CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; } - CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } // N/A @@ -136,7 +148,7 @@ struct numeric }; template -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number) const +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number) const { static_assert(I < 2, "Index is out of range."); if constexpr(I == 1) @@ -202,7 +214,7 @@ CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return bf16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } @@ -211,13 +223,13 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - 
type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return bf16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } -// TODO: make float_to_e2m1 generic so that we can convert from directrly. -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) +// TODO: make it generic so that we can convert from other types directly. +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); @@ -227,14 +239,20 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) } CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale) { - return float_to_e2m1(x, scale); +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x, scale); +#else + auto res = convert_to_type(x, scale); + return pk_fp4_t::_pack(res, res); +#endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale) @@ -242,7 +260,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale) @@ -250,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return 
pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale) @@ -258,7 +277,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale) @@ -266,7 +285,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } @@ -301,7 +320,7 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return convert_to_float(unpack(number<0>{}), scale); + return convert_to_float(_unpack(number<0>{}), scale); #endif } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const @@ -309,8 +328,8 @@ CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp32x2_t{convert_to_float(unpack(number<0>{}), scale), - convert_to_float(unpack(number<1>{}), scale)}; + return fp32x2_t{convert_to_float(_unpack(number<0>{}), scale), + convert_to_float(_unpack(number<1>{}), scale)}; #endif } @@ -319,7 +338,7 @@ CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return 
fp16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const @@ -327,28 +346,29 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return fp16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { - return e2m1_to_fp32_table[unpack(number<0>{})] * scale; + return e2m1_to_fp32_table[_unpack(number<0>{})] * scale; } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{})] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { - return type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale; + return type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale; } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { return fp16x2_t{ - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale), - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)}; + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * + scale)}; } #endif diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index e03881a1c7..817a46a8ea 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -67,7 +67,10 @@ struct 
FillUniformDistribution : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - return ck_tile::type_convert(dis(gen)); + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); }); }; threads[it] = joinable_thread(thread_f); @@ -77,8 +80,12 @@ struct FillUniformDistribution { std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); - std::generate( - first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + std::generate(first, last, [&dis, &gen]() { + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); + }); } } diff --git a/test/ck_tile/data_type/test_pk_fp4.cpp b/test/ck_tile/data_type/test_pk_fp4.cpp index 15f027e95d..b1e981624a 100644 --- a/test/ck_tile/data_type/test_pk_fp4.cpp +++ b/test/ck_tile/data_type/test_pk_fp4.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" +#include #include #include "ck_tile/core.hpp" @@ -29,6 +30,12 @@ TEST(PackedFp4, NumericLimits) EXPECT_EQ(ck_tile::numeric::epsilon(), pk_fp4_t{0b00010001}); EXPECT_EQ(ck_tile::numeric::round_error(), pk_fp4_t{0b00010001}); } +TEST(PackedFp4, fill) +{ + std::vector v_fp4(4); + ck_tile::FillUniformDistribution{1.f, 1.f}(v_fp4); + EXPECT_EQ(v_fp4[0].get(), pk_fp4_t{0b00100010}.get()); +} TEST(PackedFp4, ConvertBasic) { EXPECT_EQ(ck_tile::convert_to_type(0.0f), pk_fp4_t{0b00000000}.get()); @@ -102,7 +109,7 @@ struct SrcPkfp4Dst // ex: fp32_t -> fp4 -> bf16_t dst[i] = toDST(toPF4(src[i])); // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t - dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{}))); + dst[i + 1] = toDST(toPF4(input2).unpack(number<1>{})); } else {